From 63d2352e0832ae02968a740188849df7e3956654 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 Nov 2021 10:14:14 -0600 Subject: [PATCH 001/177] [TIR] Added BufferLoadNode::LegalizeDtype When modifying a BufferLoad object, the return dtype must also be updated. This exposes the legalization function, so that passes that use `BufferLoad::CopyOnWrite` to modify the buffer/indices don't need to repeat the logic to update the dtype returned. --- include/tvm/tir/expr.h | 8 ++++++++ src/tir/ir/expr.cc | 13 ++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index f6741112f269..4ba27fee70b0 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -610,6 +610,14 @@ class BufferLoadNode : public PrimExprNode { /*! \brief The indices location to be loaded. */ Array indices; + /*! \brief Set the dtype based on the buffer/indices + * + * Usually, this will be the same dtype as the buffer. This may + * have a different number of lanes than the buffer's dtype if index + * values have more than 1 lane. + */ + void LegalizeDtype(); + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->dtype)); v->Visit("buffer", &buffer); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index fbbd4a9522eb..afe24b73b80e 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -1056,12 +1056,23 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; }); // BufferLoad +void BufferLoadNode::LegalizeDtype() { + int index_lanes = 1; + for (const auto& index : indices) { + index_lanes *= index.dtype().lanes(); + } + + int buffer_lanes = buffer->dtype.lanes(); + + this->dtype = buffer->dtype.with_lanes(index_lanes * buffer_lanes); +} + BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { ObjectPtr node = make_object(); - node->dtype = buffer->dtype; node->buffer = std::move(buffer); node->indices = std::move(indices); node->span = std::move(span); + node->LegalizeDtype(); data_ = std::move(node); } From f698ce74b614d09ea55b7656953f19d07da2d541 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 10 Nov 2021 11:24:23 -0600 Subject: [PATCH 002/177] Replacing Store/Load in Stmt/Expr Visitor/Mutator --- src/tir/ir/expr_functor.cc | 12 +++--------- src/tir/ir/stmt_functor.cc | 18 +++--------------- 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 4c5ea5bfd2d0..c8dc84695b4f 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -35,8 +35,7 @@ void ExprVisitor::VisitExpr_(const SizeVarNode* op) { void ExprVisitor::VisitExpr_(const AnyNode* op) {} void ExprVisitor::VisitExpr_(const LoadNode* op) { - this->VisitExpr(op->index); - this->VisitExpr(op->predicate); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } void ExprVisitor::VisitExpr_(const BufferLoadNode* op) { @@ -127,13 +126,8 @@ PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { PrimExpr ExprMutator::VisitExpr_(const AnyNode* op) { return GetRef(op); } PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) { - PrimExpr index = this->VisitExpr(op->index); - PrimExpr predicate = this->VisitExpr(op->predicate); - if (index.same_as(op->index) && predicate.same_as(op->predicate)) { - return GetRef(op); - } else { - return Load(op->dtype, op->buffer_var, index, predicate); - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index d60ec72a7589..ddecf282622a 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -59,9 +59,7 @@ void StmtVisitor::VisitStmt_(const AllocateNode* op) { } void StmtVisitor::VisitStmt_(const StoreNode* op) { - this->VisitExpr(op->value); - this->VisitExpr(op->index); - this->VisitExpr(op->predicate); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { @@ -339,18 +337,8 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { } Stmt StmtMutator::VisitStmt_(const StoreNode* op) { - PrimExpr value = this->VisitExpr(op->value); - PrimExpr index = this->VisitExpr(op->index); - PrimExpr predicate = this->VisitExpr(op->predicate); - if (value.same_as(op->value) && index.same_as(op->index) && predicate.same_as(op->predicate)) { - return GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->value = std::move(value); - n->index = std::move(index); - n->predicate = std::move(predicate); - return Stmt(n); - } + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { From 4331f36b7f2d5f8f1f7024bd34580e59843716a5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 Nov 2021 10:36:45 -0600 Subject: [PATCH 003/177] Removing Store/Load from optimization passes - UpdatePointerStorageScope - UnrollLoop - ThreadSync - LinearAccessPatternFinder - StoragePlanRewriter - VectorTypeRewriter - VectorTypeAccessChecker - NarrowDataType - IRConvertSSA - CompactBufferRegion --- src/tir/transforms/compact_buffer_region.cc | 6 +- src/tir/transforms/ir_utils.cc | 61 +++- src/tir/transforms/narrow_datatype.cc | 56 +++- src/tir/transforms/storage_rewrite.cc | 280 ++++++++++++------ src/tir/transforms/thread_storage_sync.cc | 19 +- src/tir/transforms/unroll_loop.cc | 5 + .../update_pointer_storage_scope.cc | 56 +++- .../transforms/update_pointer_storage_scope.h | 8 + 8 files changed, 356 insertions(+), 135 deletions(-) diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 20ddd7f84a35..2970f81cccca 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -99,13 +99,11 @@ class BufferAccessRegionCollector : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef(op)); } void VisitExpr_(const LoadNode* op) final { - StmtExprVisitor::VisitExpr_(op); - VisitBufferVar(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } void VisitStmt_(const StoreNode* op) final { - StmtExprVisitor::VisitStmt_(op); - VisitBufferVar(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } void VisitStmt_(const ForNode* op) final { diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 4eb9cc5b1a90..7d8b1963c35b 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -111,26 +111,56 @@ class IRConvertSSA final : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } } + PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - const VarNode* v = op->buffer_var.get(); - if (scope_.count(v) && !scope_[v].empty()) { - return Load(op->dtype, scope_[v].back(), op->index, op->predicate); - } else { - return expr; - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } + Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - const VarNode* v = op->buffer_var.get(); - if (scope_.count(v) && !scope_[v].empty()) { - return Store(scope_[v].back(), op->value, op->index, op->predicate); - } else { - return stmt; + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + Buffer new_buf = GetRemappedBuffer(node->buffer); + if (!new_buf.same_as(node->buffer)) { + auto writer = node.CopyOnWrite(); + writer->buffer = new_buf; } + + return node; } + + Buffer GetRemappedBuffer(Buffer buf) { + auto key = buf.get(); + auto buf_it = buf_remap_.find(key); + if (buf_it != buf_remap_.end()) { + return buf_it->second; + } + + auto var_it = scope_.find(buf->data.get()); + if (var_it != scope_.end() && !var_it->second.empty()) { + Var buffer_var = var_it->second.back(); + auto writer = buf.CopyOnWrite(); + writer->data = buffer_var; + } + + buf_remap_[key] = buf; + return buf; + } + Stmt VisitStmt_(const LetStmtNode* op) final { const Var& v = op->var; if (defined_.count(v.get())) { @@ -191,6 +221,7 @@ class IRConvertSSA final : public StmtExprMutator { private: std::unordered_map> scope_; std::unordered_set defined_; + std::unordered_map buf_remap_; }; Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index dc34626205a1..ff790b14a34d 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -205,12 +205,52 @@ class DataTypeRewriter : public StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) final { - PrimExpr value = this->VisitExpr(op->value); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = GetRef(op); + + auto value = this->VisitExpr(op->value); + auto indices = VisitIndices(op->indices); + + if (!value.same_as(op->value) || !indices.same_as(op->indices)) { + auto writer = store.CopyOnWrite(); + writer->value = value; + writer->indices = indices; + } + + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = GetRef(op); + + auto indices = VisitIndices(op->indices); + + if (!indices.same_as(op->indices)) { + auto writer = load.CopyOnWrite(); + writer->indices = indices; + } + + return std::move(load); + } + + Array VisitIndices(Array indices) { is_index_ = true; - PrimExpr index = this->VisitExpr(op->index); + + auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; + indices.MutateByApply(fmutate); + is_index_ = false; - Stmt s = Store(op->buffer_var, op->value, index, op->predicate); - return StmtExprMutator::VisitStmt_(s.as()); + + return indices; } Stmt VisitStmt_(const ForNode* op) final { @@ -263,14 +303,6 @@ class DataTypeRewriter : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } - PrimExpr VisitExpr_(const LoadNode* op) final { - is_index_ = true; - PrimExpr index = this->VisitExpr(op->index); - is_index_ = false; - PrimExpr e = Load(op->dtype, op->buffer_var, index, op->predicate); - return StmtExprMutator::VisitExpr_(e.as()); - } - PrimExpr VisitExpr_(const IntImmNode* op) final { if (is_index_) { if (visitor_.vmap.find(op) != visitor_.vmap.end()) { diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 409b7c262954..df61a6ce978c 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -89,12 +89,17 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { alloc_info_[buf].level = level; StmtExprVisitor::VisitStmt_(op); } + void VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitStmt_(const BufferStoreNode* op) final { scope_.push_back(StmtEntry()); // visit subexpr StmtExprVisitor::VisitStmt_(op); // Add write access. - const VarNode* buf = op->buffer_var.get(); + const VarNode* buf = op->buffer->data.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()); @@ -107,6 +112,22 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { linear_seq_.push_back(e); } } + + void VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + } + + void VisitExpr_(const BufferLoadNode* op) final { + // Add write access. + StmtExprVisitor::VisitExpr_(op); + const VarNode* buf = op->buffer->data.get(); + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; + scope_[it->second.level].touched.push_back(buf); + } + } + void VisitStmt_(const EvaluateNode* op) final { scope_.push_back(StmtEntry()); // visit subexpr @@ -118,16 +139,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { linear_seq_.push_back(e); } } - void VisitExpr_(const LoadNode* op) final { - // Add write access. - StmtExprVisitor::VisitExpr_(op); - const VarNode* buf = op->buffer_var.get(); - auto it = alloc_info_.find(buf); - if (it != alloc_info_.end() && it->second.alloc) { - ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; - scope_[it->second.level].touched.push_back(buf); - } - } + void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); @@ -136,6 +148,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } } + void VisitExpr_(const VarNode* buf) final { // Directly reference to the variable count as a read. auto it = alloc_info_.find(buf); @@ -144,6 +157,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { scope_[it->second.level].touched.push_back(buf); } } + template void VisitNewScope(const T* op) { scope_.push_back(StmtEntry()); @@ -164,6 +178,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { ICHECK_NE(end_index, 0U); linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; } + void VisitStmt_(const AttrStmtNode* op) final { // Only record the outer most thread extent. if (op->attr_key == attr::thread_extent && !in_thread_env_) { @@ -178,6 +193,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } } + void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); } void VisitStmt_(const ForNode* op) final { VisitNewScope(op); } @@ -355,22 +371,62 @@ class StoragePlanRewriter : public StmtExprMutator { } return stmt; } + Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - auto it = alloc_map_.find(op->buffer_var.get()); - if (it == alloc_map_.end()) return stmt; - return Store(it->second->alloc_var, op->value, - RemapIndex(op->value.dtype(), op->index, it->second), op->predicate); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } + PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - auto it = alloc_map_.find(op->buffer_var.get()); - if (it == alloc_map_.end()) return expr; - return Load(op->dtype, it->second->alloc_var, RemapIndex(op->dtype, op->index, it->second), - op->predicate); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + + template + Node VisitBufferAccess(Node node) { + auto it = alloc_map_.find(node->buffer->data.get()); + if (it != alloc_map_.end()) { + Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var); + + Array indices = node->indices; + indices.Set(indices.size() - 1, + RemapIndex(node->buffer->dtype, indices[indices.size() - 1], it->second)); + + auto writer = node.CopyOnWrite(); + writer->buffer = buf; + writer->indices = indices; + } + return node; + } + + Buffer RemapBuffer(Buffer buf, Var new_backing_array) { + auto key = buf.get(); + auto it = buffer_remap_.find(key); + if (it != buffer_remap_.end()) { + ICHECK_EQ(it->second->data.get(), new_backing_array.get()) + << "Cannot remap buffer " << buf->name << " to use backing array " + << new_backing_array->name_hint << ", previously used backing array " + << it->second->data->name_hint; + return it->second; + } + + auto writer = buf.CopyOnWrite(); + writer->data = new_backing_array; + + buffer_remap_[key] = buf; + return buf; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); } + PrimExpr VisitExpr_(const VarNode* op) final { auto it = alloc_map_.find(op); if (it != alloc_map_.end()) { @@ -890,6 +946,8 @@ class StoragePlanRewriter : public StmtExprMutator { std::unordered_map alloc_map_; // The allocations std::vector > alloc_vec_; + // The buffer objects being remapped + std::unordered_map buffer_remap_; // analyzer arith::Analyzer analyzer_; }; @@ -1006,20 +1064,29 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } void VisitExpr_(const LoadNode* op) final { - OnArrayAccess(op->dtype, op->buffer_var.get(), op->index, op->predicate); - StmtExprVisitor::VisitExpr_(op); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } void VisitStmt_(const StoreNode* op) final { - OnArrayAccess(op->value.dtype(), op->buffer_var.get(), op->index, op->predicate); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitExpr_(const BufferLoadNode* op) final { + OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode* op) final { + OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices); StmtExprVisitor::VisitStmt_(op); } + void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); PrimExpr index = op->args[2]; - OnArrayAccess(dtype, buffer, index, const_true(dtype.lanes())); + OnArrayAccess(dtype, buffer, {index}); } StmtExprVisitor::VisitExpr_(op); } @@ -1093,8 +1160,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * * @param predicate The predicate used for the store/load. */ - void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const PrimExpr& index, - const PrimExpr& predicate) { + void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array& indices) { auto it = info_map_.find(buffer); ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer << ") occurred before its declaration."; @@ -1110,6 +1176,11 @@ class VectorTypeAccessChecker : public StmtExprVisitor { var_info.element_dtype = value_dtype.element_of(); } + int index_lanes = 1; + for (const auto& index : indices) { + index_lanes *= index.dtype().lanes(); + } + DataType access_dtype = value_dtype; int lanes_used = var_info.element_dtype.lanes(); @@ -1120,8 +1191,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { // necessary because the C-based codegens do not yet support vectorized // pointer types (e.g. float16x4*). Once they do, this if statement should // instead be replaced by the below ICHECK_EQ. - if (index.dtype().lanes() * var_info.element_dtype.lanes() != value_dtype.lanes()) { - ICHECK_EQ(index.dtype().lanes(), value_dtype.lanes()); + if (index_lanes * var_info.element_dtype.lanes() != value_dtype.lanes()) { + ICHECK_EQ(index_lanes, value_dtype.lanes()); lanes_used = 1; var_info.element_dtype = var_info.element_dtype.with_lanes(1); } @@ -1130,19 +1201,19 @@ class VectorTypeAccessChecker : public StmtExprVisitor { // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 // for discussion. - // ICHECK_EQ(index.dtype().lanes() * var_info.element_dtype.lanes(), value_dtype.lanes()) + // ICHECK_EQ(index_lanes * var_info.element_dtype.lanes(), value_dtype.lanes()) // << "Attempting to retrieve " << value_dtype.lanes() << " lanes of data with " - // << index.dtype().lanes() << " indices into an array whose elements have " + // << index_lanes << " indices into an array whose elements have " // << var_info.element_dtype.lanes() << " lanes. " - // << "Expected output with " << index.dtype().lanes() * var_info.element_dtype.lanes() + // << "Expected output with " << index_lanes * var_info.element_dtype.lanes() // << " lanes."; // If the index is a RampNode with stride of 1 and offset // divisible by the number of number of lanes, and the predicate // does not apply any masking, then this array access could be // vectorized. - const RampNode* ramp_index = index.as(); - if (ramp_index && is_one(ramp_index->stride) && is_one(predicate)) { + const RampNode* ramp_index = indices[indices.size() - 1].as(); + if (ramp_index && is_one(ramp_index->stride)) { arith::ModularSet me = analyzer_.modular_set(ramp_index->base); if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) { lanes_used = ramp_index->lanes; @@ -1244,55 +1315,92 @@ class VectorTypeRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + template + Node VisitBufferAccess(Node node) { if (!rewrite_indices_) { - return expr; + return node; } - auto it = rewrite_map_.find(op->buffer_var.get()); + auto it = rewrite_map_.find(node->buffer->data.get()); if (it == rewrite_map_.end()) { - return expr; + return node; } const auto& info = it->second; - DataType out_dtype_base = info.new_element_dtype.element_of(); + Array indices = node->indices; - const RampNode* ramp_index = op->index.as(); + const RampNode* ramp_index = indices[indices.size() - 1].as(); if (ramp_index && is_one(ramp_index->stride)) { PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); - return Load(out_dtype_base.with_lanes(op->dtype.lanes()), info.new_buffer_var, new_index, - const_true(new_index.dtype().lanes()), op->span); - } else { - return Load(out_dtype_base, info.new_buffer_var, op->index, op->predicate); + if (ramp_index->lanes != info.factor()) { + new_index = Ramp(new_index, ramp_index->stride, ramp_index->lanes / info.factor(), + ramp_index->span); + } + + indices.Set(indices.size() - 1, new_index); } + + auto writer = node.CopyOnWrite(); + writer->buffer = RemapBuffer(node->buffer); + writer->indices = indices; + + return node; } - Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + auto modified = VisitBufferAccess(node); - if (!rewrite_indices_) { - return stmt; + // Not needed for BufferStoreNode, so we can't just call + // LegalizeDtype() in VisitBufferAccess. + if (node.same_as(modified)) { + return std::move(node); + } else { + auto writer = modified.CopyOnWrite(); + writer->LegalizeDtype(); + return std::move(modified); } + } - auto it = rewrite_map_.find(op->buffer_var.get()); - if (it == rewrite_map_.end()) { - return stmt; + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + Buffer RemapBuffer(Buffer buf) { + auto cache_key = buf.get(); + + auto cache_it = buffer_map_.find(cache_key); + if (cache_it != buffer_map_.end()) { + return cache_it->second; } - const auto& info = it->second; - const RampNode* ramp_index = op->index.as(); - if (ramp_index && is_one(ramp_index->stride)) { - PrimExpr new_index = - ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); - return Store(info.new_buffer_var, op->value, new_index, const_true(new_index.dtype().lanes()), - op->span); - } else { - return Store(info.new_buffer_var, op->value, op->index, op->predicate, op->span); + auto info_it = rewrite_map_.find(buf->data.get()); + if (info_it != rewrite_map_.end()) { + auto& info = info_it->second; + + Array shape = buf->shape; + PrimExpr last_dim = shape[shape.size() - 1]; + shape.Set(shape.size() - 1, last_dim / make_const(last_dim.dtype(), info.factor())); + + auto writer = buf.CopyOnWrite(); + writer->data = info.new_buffer_var; + writer->dtype = info.new_element_dtype; + writer->shape = shape; } + + buffer_map_[cache_key] = buf; + return buf; } PrimExpr VisitExpr_(const CallNode* op) final { @@ -1316,9 +1424,9 @@ class VectorTypeRewriter : public StmtExprMutator { PrimExpr flag = op->args[4]; PrimExpr e_dtype = tir::TypeAnnotation(info.new_element_dtype); - PrimExpr factor = make_const(extent.dtype(), info.new_element_dtype.lanes()); - extent = extent / factor; - index = index / factor; + int factor = info.factor(); + extent = extent / make_const(extent.dtype(), factor); + index = index / make_const(index.dtype(), factor); Array acc_args{e_dtype, info.new_buffer_var, index, extent, flag}; return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args); @@ -1340,11 +1448,9 @@ class VectorTypeRewriter : public StmtExprMutator { Var new_buffer_var = info.new_buffer_var; - int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); - Array extents = op->extents; - extents.Set(extents.size() - 1, - extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); + PrimExpr last_extent = extents[extents.size() - 1]; + extents.Set(extents.size() - 1, last_extent / make_const(last_extent.dtype(), info.factor())); return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); } @@ -1355,7 +1461,7 @@ class VectorTypeRewriter : public StmtExprMutator { * * @param func A pointer to the PrimFunc being modified. */ - void Finalize(PrimFunc* func_ptr) const { + void Finalize(PrimFunc* func_ptr) { ICHECK(func_ptr) << "Finalize expects a non-null pointer"; auto& func = *func_ptr; auto* n = func.CopyOnWrite(); @@ -1381,29 +1487,15 @@ class VectorTypeRewriter : public StmtExprMutator { } n->params = new_params; - // Remap the Buffer objects in so that the buffers use the new buffer variables + // Remap the Buffer objects in PrimFunc::buffer_map so that the + // buffers use the new buffer variables Map new_buffer_map; for (const auto& pair : n->buffer_map) { Var key = pair.first; Buffer old_buffer = pair.second; Var old_var = old_buffer->data; - - auto it = rewrite_map_.find(old_var.get()); - if (it == rewrite_map_.end()) { - new_buffer_map.Set(key, old_buffer); - } else { - auto& info = it->second; - int factor = info.new_element_dtype.lanes() / info.old_element_dtype.lanes(); - ICHECK_EQ(factor * info.new_element_dtype.lanes(), info.old_element_dtype.lanes()); - - auto* buffer_cow = old_buffer.CopyOnWrite(); - buffer_cow->data = info.new_buffer_var; - buffer_cow->dtype = info.new_element_dtype; - size_t ndim = buffer_cow->shape.size(); - const auto& last_dim = buffer_cow->shape[ndim - 1]; - buffer_cow->shape.Set(ndim - 1, last_dim / make_const(last_dim.dtype(), factor)); - new_buffer_map.Set(key, old_buffer); - } + Buffer new_buffer = RemapBuffer(old_buffer); + new_buffer_map.Set(key, new_buffer); } n->buffer_map = new_buffer_map; } @@ -1414,10 +1506,18 @@ class VectorTypeRewriter : public StmtExprMutator { Var new_buffer_var; DataType old_element_dtype; DataType new_element_dtype; + + int factor() const { + int old_lanes = old_element_dtype.lanes(); + int new_lanes = new_element_dtype.lanes(); + ICHECK_EQ(new_lanes % old_lanes, 0); + return new_lanes / old_lanes; + } }; bool rewrite_indices_{true}; std::unordered_map rewrite_map_; + std::unordered_map buffer_map_; }; // Rewrite allocates, pointer parameters, and buffer map into vectorized versions diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 35e4563b8f58..b2b27c9dd618 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -222,16 +222,25 @@ class ThreadSyncInserter : public StmtExprMutator { } } PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[op->buffer_var].read_count; + GetScope(op->buffer->data).rank == StorageRank::kGlobal) { + ++rw_stats_[op->buffer->data].read_count; } return StmtExprMutator::VisitExpr_(op); } - Stmt VisitStmt_(const StoreNode* op) final { + Stmt VisitStmt_(const BufferStoreNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[op->buffer_var].write_count; + GetScope(op->buffer->data).rank == StorageRank::kGlobal) { + ++rw_stats_[op->buffer->data].write_count; } return StmtExprMutator::VisitStmt_(op); } diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index c6e0b5c5f41e..e1d0688ab537 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -134,6 +134,11 @@ class LoopUnroller : public StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { ++step_count_; return StmtExprMutator::VisitStmt_(op); } diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 4143577a0b17..69db85eda2df 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -58,22 +58,60 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { return it->second; } -PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { - auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); - return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), - StmtExprMutator::VisitExpr(op->predicate)); -} - Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), StmtExprMutator::VisitStmt(op->body)); } +template +Node UpdatePointerStorageScope::UpdateBufferAccess(Node node) { + auto new_buffer = GetUpdatedBuffer(node->buffer); + if (!new_buffer.same_as(node->buffer)) { + auto writer = node.CopyOnWrite(); + writer->buffer = new_buffer; + } + return node; +} + +Buffer UpdatePointerStorageScope::GetUpdatedBuffer(Buffer buf) { + // Use the cached buffer, if it exists. + auto key = buf.get(); + auto it = new_buffer_remap_.find(key); + if (it != new_buffer_remap_.end()) { + return it->second; + } + + // Update the buffer's var, if needed. + auto remapped = Downcast(StmtExprMutator::VisitExpr(buf->data)); + if (!remapped.same_as(buf->data)) { + auto writer = buf.CopyOnWrite(); + writer->data = remapped; + } + + // Update the cache and return + new_buffer_remap_[key] = buf; + return buf; +} + +PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); +} + +PrimExpr UpdatePointerStorageScope::VisitExpr_(const BufferLoadNode* op) { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return UpdateBufferAccess(node); +} + Stmt UpdatePointerStorageScope::VisitStmt_(const StoreNode* op) { - auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); - return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), - StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); +} + +Stmt UpdatePointerStorageScope::VisitStmt_(const BufferStoreNode* op) { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return UpdateBufferAccess(node); } } // namespace tir diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h index f310194a4a51..d5e492e83389 100644 --- a/src/tir/transforms/update_pointer_storage_scope.h +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -40,11 +40,19 @@ class UpdatePointerStorageScope : public StmtExprMutator { virtual PrimExpr VisitExpr_(const VarNode*); virtual PrimExpr VisitExpr_(const LoadNode*); + virtual PrimExpr VisitExpr_(const BufferLoadNode*); virtual Stmt VisitStmt_(const AllocateNode*); virtual Stmt VisitStmt_(const StoreNode*); + virtual Stmt VisitStmt_(const BufferStoreNode*); private: + template + Node UpdateBufferAccess(Node node); + + Buffer GetUpdatedBuffer(Buffer buf); + std::unordered_map new_var_remap_; + std::unordered_map new_buffer_remap_; }; } // namespace tir From ebdc8c77c2c2b741e2a349e6986d3943c7a0be2f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 15 Nov 2021 11:03:30 -0600 Subject: [PATCH 004/177] Removing Store/Load from examples - ConvertAddToSubtract --- .../backend/contrib/example_target_hooks/relay_to_tir.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 6794594b5ba4..1317ceb7a174 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -52,8 +52,8 @@ class ConvertAddToSubtract : public MixedModeMutator { } private: - tir::Load LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) { - return tir::Load(DataType::Float(32), buffer->data, index, tir::const_true()); + tir::BufferLoad LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) { + return tir::BufferLoad(buffer, {index}); } void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) { @@ -71,7 +71,7 @@ class ConvertAddToSubtract : public MixedModeMutator { te::Var index("index", DataType::Int(32)); tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index)); - tir::Stmt math_body = tir::Store(out_buffer->data, indexed_sub, index, tir::const_true()); + tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index}); tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body); Map buffer_map = { From bba004ab775c9dd4517c1cd565eadceac1efc585 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 Nov 2021 14:24:05 -0600 Subject: [PATCH 005/177] Replacing Store/Load in StorageFlatten Now, outputs BufferLoad/BufferStore with a flattened buffer object. temp commit, replacing Store/Load, BufferBindUnwrapper temp commit, replacing Store/Load, StorageFlattener --- src/tir/transforms/storage_flatten.cc | 303 ++++++++++++-------------- 1 file changed, 145 insertions(+), 158 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index ccc660509ca1..0293519c9945 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -37,6 +37,7 @@ #include #include +#include #include "../../arith/ir_visitor_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" @@ -163,43 +164,49 @@ class BufferShapeLegalize : public StmtExprMutator { } Stmt VisitStmt_(const BufferStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op); + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } - auto it = buf_map_.find(op->buffer); + template + Node VisitBufferAccess(Node node) { + auto it = buf_map_.find(node->buffer); if (it != buf_map_.end()) { const BufferEntry& entry = it->second; - ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer"; + ICHECK(entry.in_scope) << "Cannot access an out-of-scope buffer"; - BufferStore updated = GetRef(op); - auto write_ptr = updated.CopyOnWrite(); - write_ptr->indices = update_indices(op->indices, entry.index_offsets); - write_ptr->buffer = entry.remap_to; - stmt = updated; - } + Array indices = node->indices; + if (entry.index_offsets.size()) { + ICHECK_GE(entry.index_offsets.size(), indices.size()) + << "Cannot bind buffer to a shape of lower dimension."; - return stmt; - } + Array new_indices; - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - ICHECK(op); + // Pad leading indices with zero, matching the "fuzzy_match" + // behavior from ArgBinder::BindBuffer. + size_t diff = entry.index_offsets.size() - indices.size(); + for (size_t i = 0; i < diff; i++) { + new_indices.push_back(0); + } - auto it = buf_map_.find(op->buffer); - if (it != buf_map_.end()) { - const BufferEntry& entry = it->second; - ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer"; + // Offset indices used to access buffers of a reduced size. + for (size_t i = 0; i < indices.size(); i++) { + PrimExpr offset = entry.index_offsets[i + diff]; + new_indices.push_back(indices[i] - offset); + } + indices = new_indices; + } - BufferLoad updated = GetRef(op); - auto write_ptr = updated.CopyOnWrite(); - write_ptr->indices = update_indices(op->indices, entry.index_offsets); + auto write_ptr = node.CopyOnWrite(); + write_ptr->indices = indices; write_ptr->buffer = entry.remap_to; - expr = updated; } - - return expr; + return node; } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -341,36 +348,6 @@ class BufferShapeLegalize : public StmtExprMutator { return stmt; } - Array update_indices(const Array& indices, const Array& offsets) { - // offsets come from BufferRealizeNode::bounds, which is allowed - // to be empty to indicate realization of the full shape of the - // buffer. In that case, the indices do not need to be modified, - // but may need to be extended with leading zeroes. - if (offsets.size() == 0) { - return indices; - } - - ICHECK_GE(offsets.size(), indices.size()) - << "Cannot bind buffer to a shape of lower dimension."; - - Array new_indices; - - // Pad leading indices with zero, matching the "fuzzy_match" - // behavior from ArgBinder::BindBuffer. - size_t diff = offsets.size() - indices.size(); - for (size_t i = 0; i < diff; i++) { - new_indices.push_back(0); - } - - // Offset indices used to access buffers of a reduced size. - for (size_t i = 0; i < indices.size(); i++) { - PrimExpr offset = offsets[i + diff]; - new_indices.push_back(indices[i] - offset); - } - - return new_indices; - } - std::unordered_map var_remap_; std::unordered_set extern_buffers_; @@ -516,6 +493,14 @@ class BufferStrideLegalize : public StmtExprMutator { } } + // AllocateNodes may be present from tvm.tir.ir_builder. This can + // be simplified in the future by having AllocateNode hold a buffer, + // rather than a buffer_var. + Stmt VisitStmt_(const AllocateNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const BufferRealizeNode* op) final { Buffer key = op->buffer; Buffer with_strides = WithStrides(op->buffer); @@ -536,28 +521,37 @@ class BufferStrideLegalize : public StmtExprMutator { return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span); } - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - - auto it = buf_map_.find(op->buffer); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope"; + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } - return BufferLoad(e.remap_to, op->indices, op->span); + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); } - Stmt VisitStmt_(const BufferStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); + template + Node VisitBufferAccess(Node node) { + auto alloc_key = node->buffer->data.get(); + if (allocate_node_var_.count(alloc_key)) { + BufferEntry entry; + entry.remap_to = WithStrides(node->buffer); + entry.in_scope = true; + entry.is_external = false; + buf_map_[node->buffer] = entry; + allocate_node_var_.erase(alloc_key); + } - auto it = buf_map_.find(op->buffer); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; + auto it = buf_map_.find(node->buffer); + ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << node->buffer; const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope"; + ICHECK(e.in_scope) << "Cannot access a buffer " << node->buffer->name << ", out of scope"; + + auto writer = node.CopyOnWrite(); + writer->buffer = e.remap_to; - return BufferStore(e.remap_to, op->value, op->indices, op->span); + return node; } private: @@ -579,6 +573,10 @@ class BufferStrideLegalize : public StmtExprMutator { std::unordered_map buf_map_; + // Set of vars that have occurred in an AllocateNode, but haven't + // yet occurred in a BufferLoad/BufferStore. + std::unordered_set allocate_node_var_; + IRVisitorWithAnalyzer* bound_analyzer_; }; @@ -778,39 +776,13 @@ class BufferBindUnwrapper : public StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { - // TODO(Lunderberg): Change from warning to error once all mixed - // use of physical/logical layouts is removed. - DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), " - << "but is accessed as a pointer (StoreNode)."; - - ICHECK(it->second.as()); - Var new_buf_var = Downcast(it->second); - return Store(new_buf_var, op->value, op->index, op->predicate); - } else { - return stmt; - } + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { - // TODO(Lunderberg): Change from warning to error once all mixed - // use of physical/logical layouts is removed. - DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), " - << "but is accessed as a pointer (LoadNode)."; - - ICHECK(it->second.as()); - Var new_buf_var = Downcast(it->second); - return Load(op->dtype, new_buf_var, op->index, op->predicate); - } else { - return expr; - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -868,14 +840,19 @@ class BufferBindUnwrapper : public StmtExprMutator { return out; } + // AllocateNodes may be present from tvm.tir.ir_builder. This can + // be simplified in the future by having AllocateNode hold a buffer, + // rather than a buffer_var. + Stmt VisitStmt_(const AllocateNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - auto it = buf_map_.find(op->buffer.get()); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot read from buffer " << op->buffer << ", out of scope."; + const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { return BufferLoad(e.remap->target, @@ -889,10 +866,7 @@ class BufferBindUnwrapper : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - auto it = buf_map_.find(op->buffer.get()); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot write to buffer" << op->buffer << ", out of scope."; + const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { return BufferStore(e.remap->target, op->value, @@ -933,10 +907,7 @@ class BufferBindUnwrapper : public StmtExprMutator { op = stmt.as(); ICHECK(op != nullptr); - const auto& key = op->buffer.get(); - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; - const BufferEntry& e = it->second; + const BufferEntry& e = GetBufferEntry(op->buffer); ICHECK(e.in_scope) << "Read a buffer that is already out of scope"; ICHECK_EQ(e.buffer->shape.size(), op->bounds.size()) @@ -1066,11 +1037,30 @@ class BufferBindUnwrapper : public StmtExprMutator { std::unique_ptr remap{nullptr}; }; + const BufferEntry& GetBufferEntry(Buffer buffer) { + auto alloc_key = buffer->data.get(); + if (allocate_node_var_.count(alloc_key)) { + BufferEntry entry; + entry.buffer = buffer; + buf_map_[buffer.get()] = std::move(entry); + allocate_node_var_.erase(alloc_key); + } + + auto it = buf_map_.find(buffer.get()); + ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; + const BufferEntry& e = it->second; + ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope"; + return it->second; + } + // The buffer assignment map // Variable remap std::unordered_map var_remap_; // Buffer map std::unordered_map buf_map_; + // Set of vars that have occurred in an AllocateNode, but haven't + // yet occurred in a BufferLoad/BufferStore. + std::unordered_set allocate_node_var_; // Analyzer for the variable bounds, used to simplify the bounds populator. We really need the // analyzer from it. However IRVisitorWithAnalyzer* bound_analyzer_; @@ -1105,16 +1095,13 @@ class StorageFlattener : public StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { - ICHECK(it->second.as()); - Var buf_var = Downcast(it->second); - return Store(buf_var, op->value, op->index, op->predicate); - } else { - return stmt; - } + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -1130,9 +1117,8 @@ class StorageFlattener : public StmtExprMutator { if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); - auto it = buf_map_.find(buffer); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; - body = AttrStmt(it->second.buffer->data, op->attr_key, op->value, std::move(body)); + const auto& entry = GetBufferEntry(buffer); + body = AttrStmt(entry.flattened_buffer->data, op->attr_key, op->value, std::move(body)); return body; } return StmtExprMutator::VisitStmt_(op); @@ -1143,13 +1129,7 @@ class StorageFlattener : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - const auto& key = op->buffer; - - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; - - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot write to " << op->buffer << ", out of scope."; + const BufferEntry& e = GetBufferEntry(op->buffer); Stmt body = e.buffer.vstore(op->indices, op->value); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { @@ -1165,6 +1145,14 @@ class StorageFlattener : public StmtExprMutator { return body; } + // AllocateNodes may be present from tvm.tir.ir_builder. This can + // be simplified in the future by having AllocateNode hold a buffer, + // rather than a buffer_var. + Stmt VisitStmt_(const AllocateNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const BufferRealizeNode* op) final { const auto& key = op->buffer; @@ -1244,19 +1232,6 @@ class StorageFlattener : public StmtExprMutator { } } - PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { - ICHECK(it->second.as()); - Var buf_var = Downcast(it->second); - return Load(op->dtype, buf_var, op->index, op->predicate); - } else { - return expr; - } - } - PrimExpr VisitExpr_(const VarNode* op) final { auto it = var_remap_.find(op); if (it != var_remap_.end()) { @@ -1270,12 +1245,7 @@ class StorageFlattener : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - const auto& key = op->buffer; - - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot read to " << op->buffer << ", out of scope."; + const BufferEntry& e = GetBufferEntry(op->buffer); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); @@ -1288,10 +1258,7 @@ class StorageFlattener : public StmtExprMutator { op = stmt.as(); ICHECK(op != nullptr); - const auto& key = op->buffer; - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; - const BufferEntry& e = it->second; + const BufferEntry& e = GetBufferEntry(op->buffer); ICHECK(e.in_scope) << "Cannot prefetch " << op->buffer << ", out of scope."; ICHECK_EQ(e.buffer->shape.size(), op->bounds.size()) @@ -1392,9 +1359,29 @@ class StorageFlattener : public StmtExprMutator { return bound; } + const BufferEntry& GetBufferEntry(Buffer buffer) { + auto alloc_key = buffer->data.get(); + if (allocate_node_var_.count(alloc_key)) { + BufferEntry entry; + entry.buffer = buffer; + entry.flattened_buffer = buffer.GetFlattenedBuffer(); + buf_map_[buffer] = std::move(entry); + allocate_node_var_.erase(alloc_key); + } + + auto it = buf_map_.find(buffer); + ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; + const BufferEntry& e = it->second; + ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope"; + return it->second; + } + // The buffer assignment map // Variable remap std::unordered_map var_remap_; + // Set of vars that have occurred in an AllocateNode, but haven't + // yet occurred in a BufferLoad/BufferStore. + std::unordered_set allocate_node_var_; // Buffer map std::unordered_map buf_map_; // Collects shapes. From 7c33db5e4069a3fe30906af4f262da502cd392a2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 Nov 2021 16:04:25 -0600 Subject: [PATCH 006/177] Replacing Store/Load in utility passes. - StmtSimplifier - IRSubstitute - BaseInliner - FeatureVisitor --- src/autotvm/feature_visitor.cc | 10 +-- src/autotvm/feature_visitor.h | 4 +- src/tir/ir/stmt_functor.cc | 64 ++++++++++++++++---- src/tir/schedule/primitive/compute_inline.cc | 12 ++-- src/tir/transforms/simplify.cc | 34 ++++++++--- 5 files changed, 92 insertions(+), 32 deletions(-) diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc index 59cac9cc9827..17a05f024621 100644 --- a/src/autotvm/feature_visitor.cc +++ b/src/autotvm/feature_visitor.cc @@ -97,14 +97,16 @@ void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) { } // memory access -void FeatureVisitor::VisitExpr_(const LoadNode* op) { - EnterMem_(op->buffer_var, op->index); +void FeatureVisitor::VisitExpr_(const BufferLoadNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "FeatureVisitor can only be used on flattened buffers"; + EnterMem_(op->buffer->data, op->indices[0]); StmtExprVisitor::VisitExpr_(op); ExitMem_(); } -void FeatureVisitor::VisitStmt_(const StoreNode* op) { - EnterMem_(op->buffer_var, op->index); +void FeatureVisitor::VisitStmt_(const BufferStoreNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "FeatureVisitor can only be used on flattened buffers"; + EnterMem_(op->buffer->data, op->indices[0]); StmtExprVisitor::VisitStmt_(op); ExitMem_(); } diff --git a/src/autotvm/feature_visitor.h b/src/autotvm/feature_visitor.h index 8180839b0668..3d34882c77db 100644 --- a/src/autotvm/feature_visitor.h +++ b/src/autotvm/feature_visitor.h @@ -66,8 +66,8 @@ class FeatureVisitor : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final; // memory access - void VisitExpr_(const LoadNode* op) final; - void VisitStmt_(const StoreNode* op) final; + void VisitExpr_(const BufferLoadNode* op) final; + void VisitStmt_(const BufferStoreNode* op) final; using StmtExprVisitor::VisitExpr_; using StmtExprVisitor::VisitStmt_; diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index ddecf282622a..103ec1e42a20 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -633,23 +633,51 @@ class IRSubstitute : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr ret = StmtExprMutator::VisitExpr_(op); - op = ret.as(); - if (auto mapped_var = vmap_(op->buffer_var)) { - return Load(op->dtype, Downcast(mapped_var.value()), op->index, op->predicate); - } else { - return ret; - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } Stmt VisitStmt_(const StoreNode* op) final { - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - if (auto mapped_var = vmap_(op->buffer_var)) { - return Store(Downcast(mapped_var.value()), op->value, op->index, op->predicate); - } else { - return ret; + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + Buffer new_buf = GetRemappedBuffer(node->buffer); + + if (!new_buf.same_as(node->buffer)) { + auto writer = node.CopyOnWrite(); + writer->buffer = new_buf; } + + return node; + } + + Buffer GetRemappedBuffer(Buffer buf) { + auto key = buf.get(); + auto it = buf_remap_.find(key); + if (it != buf_remap_.end()) { + return it->second; + } + + if (auto mapped_var = vmap_(buf->data)) { + auto writer = buf.CopyOnWrite(); + writer->data = Downcast(mapped_var); + } + + buf_remap_[key] = buf; + return buf; } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -665,7 +693,17 @@ class IRSubstitute : public StmtExprMutator { } private: + // Caller provided function that defines the variables to be remapped. std::function(const Var&)> vmap_; + + /* \brief Generated map to track buffers being remapped. + * + * If a `Var BufferNode::data` is remapped, then all buffers + * containing that data pointer should also be remapped. This map + * is used to track buffer modifications, and ensure all instances + * of a buffer are replaced by the same modified buffer object. + */ + std::unordered_map buf_remap_; }; Stmt Substitute(Stmt stmt, std::function(const Var&)> vmap) { diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 9a9860b42bc6..6dccfc311cd4 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -200,14 +200,14 @@ class BaseInliner : public StmtExprMutator { return StmtExprMutator::VisitExpr_(var); } - PrimExpr VisitExpr_(const LoadNode* load) final { - CheckOpaqueAccess(load->buffer_var.get()); - return StmtExprMutator::VisitExpr_(load); + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } - Stmt VisitStmt_(const StoreNode* store) final { - CheckOpaqueAccess(store->buffer_var.get()); - return StmtExprMutator::VisitStmt_(store); + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } Stmt VisitStmt_(const ForNode* loop) final { diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index df8816c8f693..7d4fac8d7b2d 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -82,17 +82,37 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } } - // eliminate useless stores Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = Parent::VisitStmt_(op); - op = stmt.as(); - if (const LoadNode* load = op->value.as()) { - if (load->buffer_var.same_as(op->buffer_var) && - tir::ExprDeepEqual()(load->index, op->index)) { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + // eliminate useless stores + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(Parent::VisitStmt_(op)); + if (const BufferLoadNode* load = op->value.as()) { + if (load->buffer->data.same_as(op->buffer->data) && + ArrayDeepEqual(load->indices, op->indices) && + tir::ExprDeepEqual()(load->buffer->elem_offset, op->buffer->elem_offset) && + ArrayDeepEqual(load->buffer->shape, op->buffer->shape) && + ArrayDeepEqual(load->buffer->strides, op->buffer->strides)) { return Evaluate(0); } } - return GetRef(op); + return std::move(store); + } + + private: + bool ArrayDeepEqual(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); i++) { + if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) { + return false; + } + } + return true; } }; From da05133a99da9f1c704ea2bbd2a7919433dd34b2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 Nov 2021 15:46:01 -0600 Subject: [PATCH 007/177] Replacing Store/Load in analysis functions - StorageAccessVisitor - VarTouchedAnalysis - MemoryAccessVerifier - InplaceOpVerifier - GPUCodeVerifier - VarTouchVisitor - LCADetector - BlockReadWriteDetector - InstrumentBoundCheckers --- .../analysis/block_access_region_detector.cc | 6 +- .../analysis/buffer_access_lca_detector.cc | 6 +- src/tir/analysis/var_touch.cc | 14 +- src/tir/analysis/verify_gpu_code.cc | 12 +- src/tir/analysis/verify_memory.cc | 14 +- src/tir/transforms/bound_checker.cc | 166 +++++++++++------- src/tir/transforms/coproc_sync.cc | 3 +- src/tir/transforms/inject_virtual_thread.cc | 11 +- src/tir/transforms/storage_access.cc | 45 +++-- src/tir/transforms/storage_access.h | 9 +- src/tir/transforms/storage_flatten.cc | 4 +- src/tir/transforms/storage_rewrite.cc | 34 +++- src/tir/transforms/thread_storage_sync.cc | 52 ++++-- tests/cpp/tir_analysis_side_effect.cc | 5 +- 14 files changed, 262 insertions(+), 119 deletions(-) diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 3038eca8d338..03e02064f798 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -141,8 +141,7 @@ Array BlockReadWriteDetector::CollectOpaques() { void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef(op)); } void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { - UpdateOpaque(op->buffer_var); - ExprVisitor::VisitExpr_(op); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { @@ -194,8 +193,7 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { } void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { - UpdateOpaque(op->buffer_var); - StmtVisitor::VisitStmt_(op); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index e680d689735d..c004c86fe77a 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -120,13 +120,11 @@ class LCADetector : public StmtExprVisitor { // Explict to visit buffer data in Load and Store node. void VisitExpr_(const LoadNode* op) final { - ExprVisitor::VisitExpr_(op); - VisitBufferVar(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } void VisitStmt_(const StoreNode* op) final { - StmtVisitor::VisitStmt_(op); - VisitBufferVar(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; } void VisitBufferVar(const VarNode* op) { diff --git a/src/tir/analysis/var_touch.cc b/src/tir/analysis/var_touch.cc index c4acd2b74aad..f92afc4d15a1 100644 --- a/src/tir/analysis/var_touch.cc +++ b/src/tir/analysis/var_touch.cc @@ -44,13 +44,21 @@ class VarTouchVisitor : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { Handle(op); } + void VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + } + void VisitStmt_(const StoreNode* op) final { - Handle(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitStmt_(const BufferStoreNode* op) final { + Handle(op->buffer->data.get()); StmtVisitor::VisitStmt_(op); } - void VisitExpr_(const LoadNode* op) final { - Handle(op->buffer_var.get()); + void VisitExpr_(const BufferLoadNode* op) final { + Handle(op->buffer->data.get()); ExprVisitor::VisitExpr_(op); } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index dc1ed1c193e8..112979f17f45 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -184,7 +184,15 @@ class GPUCodeVerifier : public StmtExprVisitor { StmtVisitor::VisitStmt_(op); } - void VisitExpr_(const LoadNode* op) { + void VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + } + + void VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitExpr_(const BufferLoadNode* op) { if (op->dtype.lanes() > 1) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; @@ -197,7 +205,7 @@ class GPUCodeVerifier : public StmtExprVisitor { ExprVisitor::VisitExpr_(op); } - void VisitStmt_(const StoreNode* op) { + void VisitStmt_(const BufferStoreNode* op) { if (op->value->dtype.lanes() > 1) { if (static_cast(op->value->dtype.lanes() * op->value->dtype.bytes()) > max_vector_bytes_) { diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index b6c41b958c31..6ee30e04704a 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -89,12 +89,20 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } void VisitExpr_(const LoadNode* op) final { - HandleLoadStoreToVariable(op->buffer_var); - return StmtExprVisitor::VisitExpr_(op); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; } void VisitStmt_(const StoreNode* op) final { - HandleLoadStoreToVariable(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitExpr_(const BufferLoadNode* op) final { + HandleLoadStoreToVariable(op->buffer->data); + return StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode* op) final { + HandleLoadStoreToVariable(op->buffer->data); return StmtExprVisitor::VisitStmt_(op); } //@} diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 3b6af0644fc9..85aac3cee855 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -37,25 +37,30 @@ namespace tvm { namespace tir { +// TODO(Lunderberg): Move this pass to be before +// StorageFlatten/FlattenBuffer. That will simplify this pass, +// because it can check directly against the buffer limits. class BoundCollector : public StmtVisitor { public: BoundCollector() {} void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tir::attr::buffer_bound) { - if (const VarNode* key = op->node.as()) { - mem_to_shape[key] = op->value; + const VarNode* key = op->node.as(); + const CallNode* container = op->value.as(); + if (key && container) { + mem_to_shape[key] = container->args; } } StmtVisitor::VisitStmt_(op); } // Hashtable which maps buffer_var to shape. - std::unordered_map mem_to_shape; + std::unordered_map> mem_to_shape; }; class BoundChecker : public StmtExprMutator { public: - explicit BoundChecker(const std::unordered_map& mem_to_shape) + explicit BoundChecker(const std::unordered_map>& mem_to_shape) : mem_to_shape_(mem_to_shape) {} Stmt VisitStmt_(const AllocateNode* op) final { @@ -73,21 +78,31 @@ class BoundChecker : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { store_scope_bound_collector_.clear(); process_store_ = true; unsafe_rewritten_ = false; StmtExprMutator::VisitStmt_(op); process_store_ = false; - if (CanInstrument(op->index, op->buffer_var)) { - Collect(op->index, op->buffer_var); + if (CanInstrument(op->indices, op->buffer->data)) { + Collect(op->indices, op->buffer->data); } // The collector should has at least one item. if (store_scope_bound_collector_.size()) { PrimExpr condition = MakeCondition(); if (!condition.as()) { Stmt nop = Evaluate(1); - Stmt then_case = Store(op->buffer_var, op->value, op->index, op->predicate); + Stmt then_case = GetRef(op); Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop); Stmt body = IfThenElse(condition, then_case, else_case); return body; @@ -96,9 +111,9 @@ class BoundChecker : public StmtExprMutator { return GetRef(op); } - PrimExpr VisitExpr_(const LoadNode* op) final { - if (CanInstrument(op->index, op->buffer_var)) { - Collect(op->index, op->buffer_var); + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + if (CanInstrument(op->indices, op->buffer->data)) { + Collect(op->indices, op->buffer->data); } return StmtExprMutator::VisitExpr_(op); } @@ -108,79 +123,106 @@ class BoundChecker : public StmtExprMutator { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } - void Update(const Var& buffer_var, const Array& new_shape, const DataType& type) { + void Update(const Var& buffer_var, Array new_shape, const DataType& type) { // Sanity check at first. - if (!new_shape.size()) { + if (!ShapeIsValid(new_shape)) { return; } - for (size_t i = 0; i < new_shape.size(); ++i) { - if (!new_shape[0].defined() || !new_shape[i].dtype().is_scalar() || - is_negative_const(new_shape[i])) { - return; + new_shape.MutateByApply([&](const PrimExpr& dim) { + // Cast to uint64 to avoid potential overflow. + return make_const(DataType::UInt(64), type.lanes()) * dim; + }); + mem_to_shape_[buffer_var.get()] = new_shape; + } + + bool ShapeIsValid(const Array& shape) const { + if (!shape.defined()) { + return false; + } + for (const auto& dim : shape) { + if (!IsValidScalar(dim) || is_negative_const(dim)) { + return false; } } - // Scalarize the shape. - PrimExpr shape = - Mul(make_const(DataType::UInt(64), type.lanes()), Cast(DataType::UInt(64), new_shape[0])); - for (size_t i = 1; i < new_shape.size(); ++i) { - // Cast to unsigned to avoid integer overlow at frist. - shape = Mul(shape, Mul(make_const(DataType::UInt(64), type.lanes()), - Cast(DataType::UInt(64), new_shape[i]))); - } - mem_to_shape_[buffer_var.get()] = shape; + return true; } - bool IndexIsValid(const PrimExpr& index) const { - if (!index.defined()) { + bool IndicesAreValid(const Array& indices) const { + if (!indices.defined()) { return false; } - if (const RampNode* ramp_index = index.as()) { - return ramp_index->base.defined() && ramp_index->base.dtype().is_scalar() && - ramp_index->stride.defined() && ramp_index->stride.dtype().is_scalar() && - (ramp_index->lanes > 0); + for (const auto& index : indices) { + if (!index.defined()) { + return false; + } + + if (const RampNode* ramp_index = index.as()) { + if (!IsValidScalar(ramp_index->base)) { + return false; + } + if (!IsValidScalar(ramp_index->stride)) { + return false; + } + if (ramp_index->lanes <= 0) { + return false; + } + } } return true; } - bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const { - return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndexIsValid(index) && - !unsafe_rewritten_; + bool IsValidScalar(const PrimExpr& expr) const { + return expr.defined() && expr.dtype().is_scalar(); + } + + bool CanInstrument(const Array& indices, const Var& buffer_var) const { + return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && + IndicesAreValid(indices) && !unsafe_rewritten_; } - void Collect(PrimExpr index, Var buffer_var) { - store_scope_bound_collector_.push_back(std::make_pair(index, mem_to_shape_[buffer_var.get()])); + void Collect(Array indices, Var buffer_var) { + store_scope_bound_collector_.push_back( + std::make_pair(indices, mem_to_shape_[buffer_var.get()])); } PrimExpr MakeCondition() { PrimExpr condition; - for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) { - std::pair buffer_to_mem = store_scope_bound_collector_[i]; - PrimExpr index = buffer_to_mem.first; - PrimExpr upper_bound = buffer_to_mem.second; - - if (const RampNode* ramp_index = index.as()) { - // In case index is base + stride * i. - // Non inclusive range. - index = Add(ramp_index->base, Mul(ramp_index->stride, make_const(ramp_index->stride.dtype(), - ramp_index->lanes - 1))); + for (const auto& pair : store_scope_bound_collector_) { + Array indices = pair.first; + Array shape = pair.second; + + ICHECK_EQ(indices.size(), shape.size()) + << "Mismatch between dimension of physical shape and physical indices"; + + for (size_t i = 0; i < indices.size(); i++) { + PrimExpr index = indices[i]; + PrimExpr upper_bound = shape[i]; + + if (const RampNode* ramp_index = index.as()) { + // In case index is base + stride * i. + // Non inclusive range. + index = Add(ramp_index->base, + Mul(ramp_index->stride, + make_const(ramp_index->stride.dtype(), ramp_index->lanes - 1))); + } + + // Try to simplify index and bound. + index = analyzer_.Simplify(index); + upper_bound = analyzer_.Simplify(upper_bound); + + // Cast to the same type - signed, to be able to check lower bound. + index = Cast(DataType::Int(64), index); + upper_bound = Cast(DataType::Int(64), upper_bound); + + // Looks like a lower bound should always be zero after normalization. + PrimExpr lower_bound = make_zero(DataType::Int(64)); + + PrimExpr current_condition = And(GE(index, lower_bound), LT(index, upper_bound)); + condition = condition.defined() ? And(condition, current_condition) : current_condition; } - - // Try to simplify index and bound. - index = analyzer_.Simplify(index); - upper_bound = analyzer_.Simplify(upper_bound); - - // Cast to the same type - signed, to be able to check lower bound. - index = Cast(DataType::Int(64), index); - upper_bound = Cast(DataType::Int(64), upper_bound); - - // Looks like a lower bound should always be zero after normalization. - PrimExpr lower_bound = make_zero(DataType::Int(64)); - - PrimExpr current_condition = And(GE(index, lower_bound), LT(index, upper_bound)); - condition = !i ? current_condition : And(condition, current_condition); } return condition; } @@ -190,11 +232,11 @@ class BoundChecker : public StmtExprMutator { // Whether we face tvm_if_then_else intrinsic. bool unsafe_rewritten_{false}; // Pool which collects the pair of index and shape for specific store/load. - std::vector> store_scope_bound_collector_; + std::vector, Array>> store_scope_bound_collector_; // Error message. const char* const error_message_ = "OUT OF THE BOUNDS"; // Hashtable which maps buffer_var to shape. - std::unordered_map mem_to_shape_; + std::unordered_map> mem_to_shape_; // internal analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 7a6d2d37c376..eb1ade173a38 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -325,7 +325,8 @@ class CoProcBarrierDetector : public StorageAccessVisitor { Array wset; for (const AccessEntry& acc : wvec) { ICHECK(acc.dtype == wvec[0].dtype); - wset.push_back(acc.touched); + ICHECK_EQ(acc.touched.size(), 1) << "CoProcBarrierDetector expects flat memory"; + wset.push_back(acc.touched[0]); } Range none; Range r = arith::Union(wset).CoverRange(none); diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 4964bec0334e..59391554948e 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -101,11 +101,18 @@ class VarTouchedAnalysis : public StmtVisitor { Record(op->var.get(), tc); this->VisitStmt(op->body); } + void VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitStmt_(const BufferStoreNode* op) final { ExprTouched tc(touched_var_, false); tc(op->value); - tc(op->index); - Record(op->buffer_var.get(), tc); + for (const auto& index : op->indices) { + tc(index); + } + Record(op->buffer->data.get(), tc); } void VisitStmt_(const ForNode* op) final { ExprTouched tc(touched_var_, false); diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 0567c8613fcd..025233df56a7 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -34,15 +34,25 @@ namespace tvm { namespace tir { void StorageAccessVisitor::VisitExpr_(const LoadNode* op) { - const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(op->buffer_var); - if (Enabled(buf, scope)) { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; +} + +void StorageAccessVisitor::VisitStmt_(const StoreNode* op) { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; +} + +void StorageAccessVisitor::VisitExpr_(const BufferLoadNode* op) { + Var buf = op->buffer->data; + StorageScope scope = GetScope(buf); + if (Enabled(buf.get(), scope)) { ICHECK(allow_append_) << op << " " << scope.to_string(); AccessEntry e; e.threads = env_threads(); - e.buffer = op->buffer_var; + e.buffer = buf; e.dtype = op->dtype.element_of(); - e.touched = arith::IntSet::Vector(op->index); + for (const auto& index : op->indices) { + e.touched.push_back(arith::IntSet::Vector(index)); + } e.type = kRead; e.scope = scope; curr_stmt_.access.emplace_back(std::move(e)); @@ -51,18 +61,21 @@ void StorageAccessVisitor::VisitExpr_(const LoadNode* op) { StmtExprVisitor::VisitExpr_(op); } -void StorageAccessVisitor::VisitStmt_(const StoreNode* op) { +void StorageAccessVisitor::VisitStmt_(const BufferStoreNode* op) { allow_append_ = true; ICHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; - const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(op->buffer_var); - if (Enabled(buf, scope)) { + + Var buf = op->buffer->data; + StorageScope scope = GetScope(buf); + if (Enabled(buf.get(), scope)) { AccessEntry e; e.threads = env_threads(); - e.buffer = op->buffer_var; + e.buffer = buf; e.dtype = op->value.dtype().element_of(); - e.touched = arith::IntSet::Vector(op->index); + for (const auto& index : op->indices) { + e.touched.push_back(arith::IntSet::Vector(index)); + } e.type = kWrite; e.scope = scope; curr_stmt_.access.emplace_back(std::move(e)); @@ -151,8 +164,12 @@ void StorageAccessVisitor::VisitStmt_(const ForNode* op) { arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent)); for (AccessEntry& e : s.access) { if (e.buffer.defined()) { - ICHECK(e.touched.defined()); - e.touched = arith::EvalSet(e.touched, relax_map); + ICHECK(e.touched.size()); + Array new_touched; + for (const auto& touched : e.touched) { + new_touched.push_back(arith::EvalSet(touched, relax_map)); + } + e.touched = std::move(new_touched); } } } @@ -213,7 +230,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { e.threads = env_threads(); e.dtype = dtype; e.buffer = Downcast(op->args[1]); - e.touched = arith::IntSet::FromRange(Range::FromMinExtent(offset, extent)); + e.touched = {arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))}; e.scope = scope; if (flag->value & 1) { e.type = kRead; diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index 9dc4c923b054..a48ee73f17fc 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -61,8 +61,11 @@ class StorageAccessVisitor : public StmtExprVisitor { Var buffer = NullValue(); /*! \brief The access data type */ DataType dtype; - /*! \brief The touched access range */ - arith::IntSet touched; + /*! \brief The touched access range + * + * Has one IntSet for each index in the buffer being accessed. + */ + Array touched; /*! \brief The type of access */ AccessType type; /*! \brief The storage scope */ @@ -80,6 +83,8 @@ class StorageAccessVisitor : public StmtExprVisitor { // override visitor pattern void VisitExpr_(const LoadNode* op) final; void VisitStmt_(const StoreNode* op) final; + void VisitExpr_(const BufferLoadNode* op) final; + void VisitStmt_(const BufferStoreNode* op) final; void VisitStmt_(const EvaluateNode* op) final; void VisitStmt_(const AttrStmtNode* op) final; void VisitStmt_(const ForNode* op) final; diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 0293519c9945..b4e770363aaa 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1356,7 +1356,9 @@ class StorageFlattener : public StmtExprMutator { for (size_t i = 1; i < shape.size(); ++i) { bound = Mul(bound, Mul(make_const(bound.dtype(), type.lanes()), shape[i])); } - return bound; + Array bounds{bound}; + + return Call(DataType::Handle(), builtin::tvm_tuple(), bounds); } const BufferEntry& GetBufferEntry(Buffer buffer) { diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index df61a6ce978c..b1cb35341840 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -256,6 +256,8 @@ class InplaceOpVerifier : public StmtExprVisitor { VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { VisitStmt_(static_cast(stmt)); + } else if (stmt->IsInstance()) { + VisitStmt_(static_cast(stmt)); } else { return false; } @@ -282,17 +284,21 @@ class InplaceOpVerifier : public StmtExprVisitor { } void VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitStmt_(const BufferStoreNode* op) final { ++mem_nest_; - this->VisitExpr(op->index); + for (const auto& index : op->indices) { + this->VisitExpr(index); + } --mem_nest_; - if (op->buffer_var.get() == dst_) { + if (op->buffer->data.get() == dst_) { store_ = op; this->VisitExpr(op->value); - this->VisitExpr(op->predicate); store_ = nullptr; } else { this->VisitExpr(op->value); - this->VisitExpr(op->predicate); } } @@ -306,7 +312,11 @@ class InplaceOpVerifier : public StmtExprVisitor { } void VisitExpr_(const LoadNode* op) final { - const VarNode* buf = op->buffer_var.get(); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + } + + void VisitExpr_(const BufferLoadNode* op) final { + const VarNode* buf = op->buffer->data.get(); // cannot read from dst_ (no reduction) if (buf == dst_) { result_ = false; @@ -318,11 +328,19 @@ class InplaceOpVerifier : public StmtExprVisitor { return; } if (src_ == buf) { - if (store_ == nullptr || store_->value.dtype() != op->dtype || - !tir::ExprDeepEqual()(store_->index, op->index)) { + if (store_ == nullptr || store_->value.dtype() != op->dtype) { result_ = false; return; } + ICHECK_EQ(store_->indices.size(), op->indices.size()) + << "Store/Load occur to the same buffer " << buf->name_hint + << " with differing number of indices"; + for (size_t i = 0; i < store_->indices.size(); i++) { + if (!tir::ExprDeepEqual()(store_->indices[i], op->indices[i])) { + result_ = false; + return; + } + } } ++mem_nest_; StmtExprVisitor::VisitExpr_(op); @@ -340,7 +358,7 @@ class InplaceOpVerifier : public StmtExprVisitor { // it is not safe to inplace when there is nested load like A[B[i]] int mem_nest_{0}; // The current store to be inspected - const StoreNode* store_{nullptr}; + const BufferStoreNode* store_{nullptr}; }; /* \brief Rewrite and merge memory allocation. diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index b2b27c9dd618..ce3f8fd3e3ac 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -177,22 +177,54 @@ class ThreadSyncPlanner : public StorageAccessVisitor { private: // find conflicting entry in vec. - bool FindConflict(const std::vector& vec, const AccessEntry& e, bool loop_carry) { - for (const AccessEntry& x : vec) { - if (x.buffer.same_as(e.buffer)) { - // Assumes no race between threads - // Same index value means no conflicts - // TODO(tqchen) more standard set based testing. - if (e.touched.IsSinglePoint() && x.touched.IsSinglePoint()) { - if (ExprDeepEqual()(e.touched.PointValue(), x.touched.PointValue())) continue; - } - if (x.double_buffer_write && e.type == kRead && !loop_carry) continue; + bool FindConflict(const std::vector& prev, const AccessEntry& curr, + bool loop_carry) { + for (const AccessEntry& x : prev) { + if (FindConflict(x, curr, loop_carry)) { return true; } } return false; } + bool FindConflict(const AccessEntry& prev, const AccessEntry& curr, bool loop_carry) { + // Access to different buffers does not conflict. + if (!prev.buffer.same_as(curr.buffer)) { + return false; + } + + // Assumes no race between threads + // Same index value means no conflicts + // TODO(tqchen) more standard set based testing. + bool has_same_index = true; + for (size_t i = 0; i < prev.touched.size(); i++) { + const auto& prev_intset = prev.touched[i]; + const auto& curr_intset = curr.touched[i]; + + bool provably_same_index = + prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint() && + ExprDeepEqual()(prev_intset.PointValue(), curr_intset.PointValue()); + + if (!provably_same_index) { + has_same_index = false; + break; + } + } + if (has_same_index) { + return false; + } + + // If this is a read into a double buffer that was previously + // swapped out, then it doesn't conflict. + if (prev.double_buffer_write && curr.type == kRead && !loop_carry) { + return false; + } + + // If nothing else allows sharing the same buffer, then they are + // in conflict. + return true; + } + private: // synchronization scope StorageScope sync_scope_; diff --git a/tests/cpp/tir_analysis_side_effect.cc b/tests/cpp/tir_analysis_side_effect.cc index a59e4a7f8c05..bd7d7805e7aa 100644 --- a/tests/cpp/tir_analysis_side_effect.cc +++ b/tests/cpp/tir_analysis_side_effect.cc @@ -25,10 +25,9 @@ TEST(SimplePasses, SideEffect) { using namespace tvm; - auto A = tir::Var("A", DataType::Handle()); + auto buf = tir::decl_buffer({16}, DataType::Float(32)); auto i = tir::Var("i", DataType::Int(32)); - ICHECK(tir::SideEffect(tir::Load(DataType::Float(32), A, i, tir::const_true(1))) == - tir::CallEffectKind::kReadState); + ICHECK(tir::SideEffect(tir::BufferLoad(buf, {i})) == tir::CallEffectKind::kReadState); ICHECK(tir::SideEffect(exp(tir::Cast(DataType::Float(32), i + 1))) == tir::CallEffectKind::kPure); ICHECK(tir::SideEffect(tir::Call(DataType::Handle(), tir::builtin::tvm_storage_sync(), {})) == tir::CallEffectKind::kUpdateState); From dba20317ec23656e958df09ca70d0cac15933b7f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 15 Nov 2021 11:14:53 -0600 Subject: [PATCH 008/177] Replacing Store/Load in lowering/legalization passes. - MakeCrossThreadReduction - CacheReadRewriter/CacheWriteRewriter - InjectVirtualThread - InjectDoubleBuffer - InjectCopyIntrin - LowerWarpMemory - LowerThreadAllreduce - LowerThreadAllreduce - LowerCustomDatatypes - LowerTVMBuiltin - CoProcSync - MergeDynamicSharedMemAllocations - VectorizeLoop - BF16Legalize --- src/te/operation/cross_thread_reduction.cc | 46 ++--- .../schedule/primitive/cache_read_write.cc | 30 +-- src/tir/transforms/bf16_legalize.cc | 24 +-- src/tir/transforms/coproc_sync.cc | 16 +- src/tir/transforms/inject_copy_intrin.cc | 40 +++- src/tir/transforms/inject_double_buffer.cc | 100 ++++++--- src/tir/transforms/inject_virtual_thread.cc | 149 ++++++++----- src/tir/transforms/lower_custom_datatypes.cc | 70 +++++-- src/tir/transforms/lower_thread_allreduce.cc | 195 ++++++++++-------- src/tir/transforms/lower_tvm_builtin.cc | 183 +++++++++++++--- src/tir/transforms/lower_warp_memory.cc | 107 +++++++--- ...merge_dynamic_shared_memory_allocations.cc | 81 ++++++-- src/tir/transforms/rewrite_unsafe_select.cc | 6 +- src/tir/transforms/vectorize_loop.cc | 184 +++++++++++++---- 14 files changed, 848 insertions(+), 383 deletions(-) diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 2ed5fd4029a2..e419377e7664 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -134,29 +134,25 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, // If we load from and then store into the same res_handles in the thread_allreduce intrinsic, // something goes wrong, so we use an extra variable here for normal reduction. - std::vector normal_res_handles; + std::vector normal_res_buffers; std::vector normal_init, normal_update; if (!normal_red.empty()) { - normal_res_handles.reserve(size); + normal_res_buffers.reserve(size); normal_init.reserve(size); normal_update.resize(size); const CommReducerNode* combiner = reduces[0]->combiner.as(); ICHECK(combiner); Array lhs; for (size_t i = 0; i < size; ++i) { - DataType t = reduces[i]->dtype; - normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), - PointerType(PrimType(t), "local")); - lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); + normal_res_buffers.push_back( + decl_buffer({1}, reduces[i]->dtype, "normal_reduce_temp" + std::to_string(i), "local")); + lhs.push_back(BufferLoad(normal_res_buffers[i], {0})); } Array init_value = combiner->identity_element; Array update_value = (*combiner)(lhs, reduces[0]->source); for (size_t i = 0; i < size; ++i) { - DataType t = reduces[i]->dtype; - normal_init.emplace_back( - Store(normal_res_handles[i], init_value[i], 0, const_true(t.lanes()))); - normal_update.emplace_back( - Store(normal_res_handles[i], update_value[i], 0, const_true(t.lanes()))); + normal_init.emplace_back(BufferStore(normal_res_buffers[i], init_value[i], {0})); + normal_update.emplace_back(BufferStore(normal_res_buffers[i], update_value[i], {0})); } } @@ -164,8 +160,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, freduce_args.push_back(make_const(DataType::UInt(32), static_cast(size))); for (size_t i = 0; i < size; ++i) { if (!normal_red.empty()) { - DataType t = reduces[i]->dtype; - freduce_args.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); + freduce_args.push_back(BufferLoad(normal_res_buffers[i], {0})); } else { freduce_args.push_back(reduces[0]->source[i]); } @@ -174,12 +169,15 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, // No constraints on the thread reduction step. It may have redundent // computation for rare cases. TODO(tvm-team): revisit this. freduce_args.push_back(const_true(1)); - std::vector res_handles(size); + std::vector res_buffers(size); for (size_t idx = 0; idx < size; ++idx) { - DataType dtype = reduces[idx]->dtype; - res_handles[idx] = - Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local")); - freduce_args.push_back(res_handles[idx]); + res_buffers[idx] = + decl_buffer({1}, reduces[idx]->dtype, "reduce_temp" + std::to_string(idx), "local"); + // Make a BufferLoad object so that we can pass the entire Buffer + // object through to LowerThreadAllreduce. The index here is + // unused. + PrimExpr dummy_load = BufferLoad(res_buffers[idx], {0}); + freduce_args.push_back(dummy_load); } for (IterVar iv : stage->leaf_iter_vars) { @@ -216,18 +214,18 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, std::vector assigns(size); for (size_t idx = 0; idx < size; ++idx) { - DataType t = reduces[idx]->dtype; - assigns[idx] = ProducerStore(stage->op.output(idx), - Load(t, res_handles[idx], 0, const_true(t.lanes())), args); + assigns[idx] = ProducerStore(stage->op.output(idx), BufferLoad(res_buffers[idx], {0}), args); } Stmt assign_body = SeqStmt::Flatten(assigns); assign_body = MergeNest(MakeIfNest(output_preds), assign_body); Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { - body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + const auto& res_buffer = res_buffers[idx - 1]; + body = Allocate(res_buffer->data, res_buffer->dtype, res_buffer->shape, const_true(), body); if (!normal_red.empty()) { - body = - Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + const auto& normal_res_buffer = normal_res_buffers[idx - 1]; + body = Allocate(normal_res_buffer->data, normal_res_buffer->dtype, normal_res_buffer->shape, + const_true(), body); } } body = Substitute(body, value_map); diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 05695a8c4dc4..1c66cbfbfd0e 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -454,13 +454,9 @@ class CacheReadRewriter : public StmtExprMutator { return ExprMutator::VisitExpr_(load); } - PrimExpr VisitExpr_(const LoadNode* load) final { - if (load->buffer_var.same_as(info_->read_buffer->data)) { - ObjectPtr n = make_object(*load); - n->buffer_var = info_->write_buffer->data; - return PrimExpr(n); - } - return ExprMutator::VisitExpr_(load); + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } PrimExpr VisitExpr_(const VarNode* op) final { @@ -573,22 +569,14 @@ class CacheWriteRewriter : public StmtExprMutator { return ExprMutator::VisitExpr_(load); } - PrimExpr VisitExpr_(const LoadNode* load) final { - if (load->buffer_var.same_as(info_->write_buffer->data)) { - ObjectPtr n = make_object(*load); - n->buffer_var = info_->read_buffer->data; - return PrimExpr(n); - } - return ExprMutator::VisitExpr_(load); + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } - Stmt VisitStmt_(const StoreNode* store) final { - if (store->buffer_var.same_as(info_->write_buffer->data)) { - ObjectPtr n = make_object(*store); - n->buffer_var = info_->read_buffer->data; - return Stmt(n); - } - return StmtMutator::VisitStmt_(store); + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } PrimExpr VisitExpr_(const VarNode* op) final { diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 76845cbebd2a..e398f75cb0fe 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -269,16 +269,8 @@ class BF16LowerRewriter : public StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) final { - // NOTE: we do not explicit recursivly mutate op->buffer_var - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - - auto it = var_remap_.find(op->buffer_var); - if (it != var_remap_.end()) { - return Store(it->second, op->value, op->index, op->predicate); - } else { - return ret; - } + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -294,16 +286,8 @@ class BF16LowerRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr ret = StmtExprMutator::VisitExpr_(op); - op = ret.as(); - - if (op->dtype.is_bfloat16()) { - auto it = var_remap_.find(op->buffer_var); - ICHECK(it != var_remap_.end()) << "bfloat* var needs to be remapped"; - return Load(DataType::UInt(16, op->dtype.lanes()), it->second, op->index, op->predicate); - } else { - return ret; - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } PrimExpr VisitExpr_(const FloatImmNode* op) final { diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index eb1ade173a38..f3a9f990599f 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -39,18 +39,24 @@ namespace tir { class CoProcTouchedBuffer : public StmtExprVisitor { public: void VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + } + void VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + void VisitExpr_(const BufferLoadNode* op) final { if (in_scope_) { - touched_[op->buffer_var.get()].coproc = true; + touched_[op->buffer->data.get()].coproc = true; } else { - touched_[op->buffer_var.get()].normal = true; + touched_[op->buffer->data.get()].normal = true; } StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const StoreNode* op) final { + void VisitStmt_(const BufferStoreNode* op) final { if (in_scope_) { - touched_[op->buffer_var.get()].coproc = true; + touched_[op->buffer->data.get()].coproc = true; } else { - touched_[op->buffer_var.get()].normal = true; + touched_[op->buffer->data.get()].normal = true; } StmtExprVisitor::VisitStmt_(op); } diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index f99cbd5b5a05..32845f4ca60d 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -63,7 +63,7 @@ class CopyIntrinInjector : public StmtMutator { loops.push_back(op); body = op->body; } - const StoreNode* store = body.as(); + auto store = body.as(); if (store == nullptr) return false; // Expr sel_cond, sel_true_value, sel_false_value; // match select or if @@ -72,17 +72,17 @@ class CopyIntrinInjector : public StmtMutator { select(sel_cond, sel_true_value, sel_false_value).Match(store->value); const CastNode* cast = store->value.as(); - const LoadNode* load = store->value.as(); + auto load = store->value.as(); if (0 == loops.size()) { ICHECK(!has_cond); } // for now only support true condition matching if (has_cond) { - load = sel_true_value.Eval().as(); + load = sel_true_value.Eval().as(); } // cast can be part of the pattern if (cast != nullptr) { - load = cast->value.as(); + load = cast->value.as(); } if (load == nullptr) return false; if (load->dtype.lanes() != 1) return false; @@ -90,8 +90,17 @@ class CopyIntrinInjector : public StmtMutator { for (const ForNode* op : loops) { loop_vars.push_back(op->loop_var); } - Array store_strides = arith::DetectLinearEquation(store->index, loop_vars); - Array load_strides = arith::DetectLinearEquation(load->index, loop_vars); + // TODO(Lunderberg): Move this pass to be before + // StorageFlatten/FlattenBuffer. That will simplify the + // implementation, since the pre-flattened indices/strides can be + // used directly. + ICHECK((store->indices.size() == 1) && (load->indices.size() == 1)) + << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + + Array store_strides = arith::DetectLinearEquation(store->indices[0], loop_vars); + Array load_strides = arith::DetectLinearEquation(load->indices[0], loop_vars); if (load_strides.size() == 0 || store_strides.size() == 0) return false; Array dst_shape; const size_t loop_var_size = loop_vars.size(); @@ -145,10 +154,21 @@ class CopyIntrinInjector : public StmtMutator { src_strides.push_back(make_const(DataType::Int(32), 1)); dst_strides.push_back(make_const(DataType::Int(32), 1)); } - Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, - store_strides[loop_var_size], store->buffer_var->name_hint, 0, 0, kDefault); - Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset, - load->buffer_var->name_hint, 0, 0, kDefault); + Buffer dst = store->buffer; + { + auto writer = dst.CopyOnWrite(); + writer->shape = dst_shape; + writer->strides = dst_strides; + writer->elem_offset = store_strides[loop_var_size]; + } + + Buffer src = load->buffer; + { + auto writer = src.CopyOnWrite(); + writer->shape = src_shape; + writer->strides = src_strides; + writer->elem_offset = src_elem_offset; + } *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); ICHECK(out->defined()) << "flower function did not return correct stmt"; return true; diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 0b45bde28dfe..d39538c0faf0 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -103,26 +103,27 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AllocateNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + const VarNode* buf = op->buffer_var.as(); auto it = dbuffer_info_.find(buf); if (it != dbuffer_info_.end()) { it->second.scope = GetPtrStorageScope(op->buffer_var); - it->second.stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), op->extents) * - op->dtype.lanes(); - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - Array new_extents{make_const(op->extents[0].dtype(), 2)}; - for (PrimExpr e : op->extents) { - new_extents.push_back(e); - } + + ICHECK_EQ(op->extents.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + it->second.stride = op->extents[0]; + + Array new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)}; ICHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; alloc_nest.emplace_back( Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); return op->body; } else { - return StmtExprMutator::VisitStmt_(op); + return stmt; } } @@ -170,34 +171,77 @@ class DoubleBufferInjector : public StmtExprMutator { return stmt; } + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - auto it = dbuffer_info_.find(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + + auto it = dbuffer_info_.find(node->buffer->data.get()); if (it != dbuffer_info_.end()) { const StorageEntry& e = it->second; ICHECK(in_double_buffer_scope_); - ICHECK(e.stride.defined()); - return Store(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index, - op->predicate); - } else { - return stmt; + ICHECK(e.switch_write_var.defined()); + + ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + + auto writer = node.CopyOnWrite(); + writer->buffer = GetRemappedBuffer(node->buffer, e.stride); + writer->indices = {e.switch_write_var * e.stride + node->indices[0]}; } + + return std::move(node); } - PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - auto it = dbuffer_info_.find(op->buffer_var.get()); + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + + auto it = dbuffer_info_.find(node->buffer->data.get()); if (it != dbuffer_info_.end()) { const StorageEntry& e = it->second; - ICHECK(e.stride.defined()); ICHECK(e.switch_read_var.defined()); - return Load(op->dtype, op->buffer_var, e.switch_read_var * e.stride + op->index, - op->predicate); - } else { - return expr; + + ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + + auto writer = node.CopyOnWrite(); + writer->buffer = GetRemappedBuffer(node->buffer, e.stride); + writer->indices = {e.switch_read_var * e.stride + node->indices[0]}; } + + return std::move(node); + } + + Buffer GetRemappedBuffer(Buffer buf, PrimExpr stride) { + auto key = buf.get(); + auto it = buf_remap_.find(key); + if (it != buf_remap_.end()) { + return it->second; + } + + ICHECK(stride.defined()); + // TODO(Lunderberg): Move this pass to before + // StorageFlatten/FlattenBuffer. That will simplify the + // implementation, to be the insertion of a new dimension for the + // buffer, rather than adjusting the other indices. + ICHECK_EQ(buf->shape.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + auto writer = buf.CopyOnWrite(); + writer->shape = {buf->shape[0] * stride}; + + buf_remap_[key] = buf; + return buf; } PrimExpr VisitExpr_(const VarNode* op) final { @@ -261,6 +305,8 @@ class DoubleBufferInjector : public StmtExprMutator { std::unordered_map > loop_pre_; // The allocation size of the buffer std::unordered_map dbuffer_info_; + // The updated Buffer objects + std::unordered_map buf_remap_; }; namespace transform { diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 59391554948e..f6ce88cf1707 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -50,7 +50,10 @@ class ExprTouched final : public StmtExprVisitor { StmtExprVisitor::VisitStmt(n); } void VisitExpr_(const LoadNode* op) final { - HandleUseVar(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + void VisitExpr_(const BufferLoadNode* op) final { + HandleUseVar(op->buffer->data.get()); StmtExprVisitor::VisitExpr_(op); } void VisitExpr_(const VarNode* op) final { HandleUseVar(op); } @@ -211,20 +214,6 @@ class VTInjector : public StmtExprMutator { PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const { return index + var_ * alloc_extent; } - // Load - PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (touched_var_.count(op->buffer_var.get())) { - visit_touched_var_ = true; - } - auto it = alloc_remap_.find(op->buffer_var.get()); - if (it != alloc_remap_.end()) { - return Load(op->dtype, op->buffer_var, RewriteIndex(op->index, it->second), op->predicate); - } else { - return expr; - } - } // Expression. PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { @@ -237,7 +226,8 @@ class VTInjector : public StmtExprMutator { PrimExpr offset = this->VisitExpr(op->args[2]); PrimExpr extent = this->VisitExpr(op->args[3]); PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); - offset = stride * var_ + offset; + offset = RewriteIndex(offset, stride); + return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]}); } else if (op->op.same_as(builtin::tvm_context_id())) { return allow_share_ ? GetRef(op) : var_; @@ -249,21 +239,61 @@ class VTInjector : public StmtExprMutator { trigger_base_inject_ = !allow_share_; return StmtExprMutator::VisitStmt_(op); } + // Load + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } // Store Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (touched_var_.count(op->buffer_var.get())) { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + // BufferLoad + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + // BufferStore + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + trigger_base_inject_ = !allow_share_; + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + if (touched_var_.count(node->buffer->data.get())) { visit_touched_var_ = true; } - trigger_base_inject_ = !allow_share_; - auto it = alloc_remap_.find(op->buffer_var.get()); + + auto it = alloc_remap_.find(node->buffer->data.get()); if (it != alloc_remap_.end()) { - return Store(op->buffer_var, op->value, RewriteIndex(op->index, it->second), op->predicate); - } else { - return stmt; + ICHECK_EQ(node->indices.size(), 1) + << "InjectVirtualThread expects rewritten allocations to be flat memory."; + auto writer = node.CopyOnWrite(); + writer->buffer = GetRemappedBuffer(node->buffer, it->second); + writer->indices = {RewriteIndex(node->indices[0], it->second)}; } + + return node; } + + Buffer GetRemappedBuffer(Buffer buf, PrimExpr alloc_extent) { + auto key = buf.get(); + auto it = buf_remap_.find(key); + if (it != buf_remap_.end()) { + return it->second; + } + + ICHECK_EQ(buf->shape.size(), 1) << "Expected buffers being rewritten to already be flattened."; + auto writer = buf.CopyOnWrite(); + writer->shape = {buf->shape[0] * alloc_extent}; + + buf_remap_[key] = buf; + return buf; + } + // Attribute Stmt VisitStmt_(const AttrStmtNode* op) final { PrimExpr value = this->VisitExpr(op->value); @@ -361,46 +391,44 @@ class VTInjector : public StmtExprMutator { } // Allocate Stmt VisitStmt_(const AllocateNode* op) final { + Allocate node = GetRef(op); + PrimExpr condition = this->VisitExpr(op->condition); + + Array extents = op->extents; + extents.MutateByApply([this](const PrimExpr& extent) { return this->VisitExpr(extent); }); + if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); } - bool changed = false; - Array extents; - for (size_t i = 0; i < op->extents.size(); i++) { - PrimExpr new_ext = this->VisitExpr(op->extents[i]); - if (visit_touched_var_ && !vt_loop_injected_) { - return InjectVTLoop(GetRef(op), true); - } - if (!new_ext.same_as(op->extents[i])) changed = true; - extents.push_back(new_ext); - } visit_touched_var_ = false; - Stmt body; - // always rewrite if not allow sharing. + // Rewrite the buffer if its shape or any value stored in it + // depends on the virtual thread var. If `allow_share_` is false, + // then the buffer is always rewritten, even if separate virtual + // threads only read from the buffer. if (touched_var_.count(op->buffer_var.get()) || !allow_share_) { // place v on highest dimension. - PrimExpr stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), op->extents) * - op->dtype.lanes(); - Array other; - other.push_back(make_const(op->extents[0].dtype(), num_threads_)); - for (PrimExpr e : extents) { - other.push_back(e); - } - extents = other; - changed = true; - // mark this buffer get touched. + + // TODO(Lunderberg): Move pass to apply before + // StorageFlatten/FlattenBuffer. Would rewrite the Buffer to + // add the injected virtual thread as the first index. + ICHECK_EQ(extents.size(), 1) + << "InjectVirtualThread expects rewritten allocations to be flat memory."; + PrimExpr stride = extents[0]; + extents = {stride * num_threads_}; + + // Mark the buffer var as touched. BufferLoad/BufferStore should + // access locations at `current_index + stride*vthread_var`. alloc_remap_[op->buffer_var.get()] = stride; - // Mutate the body. - body = this->VisitStmt(op->body); - } else { - // Mutate the body. - body = this->VisitStmt(op->body); } - if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) { + + // Mutate the body. Depends on alloc_remap_. + auto body = this->VisitStmt(op->body); + + if (extents.same_as(op->extents) && body.same_as(op->body) && + condition.same_as(op->condition)) { return GetRef(op); } else { return Allocate(op->buffer_var, op->dtype, extents, condition, body); @@ -455,8 +483,21 @@ class VTInjector : public StmtExprMutator { const std::unordered_set& touched_var_; // Whether allow shareding. bool allow_share_; - // The allocations that get touched -> extent + /* \brief The allocations that get touched -> extent + * + * Maps from the buffer_var of an allocate node to the original + * extent of the allocation. Used when rewriting the indices of + * BufferLoad/BufferStore. + */ std::unordered_map alloc_remap_; + /*! \brief Map of buffers that are modified. + * + * Buffers allocated or written to within the virtual thread loop + * must have one copy per virtual thread. This is done by enlarging + * the allocated buffer size, then modifying the indices at which + * each virtual thread accesses the buffer. + */ + std::unordered_map buf_remap_; }; class VirtualThreadInjector : public StmtMutator { diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 21f1b18d523b..16afa1133f68 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -103,32 +103,59 @@ class CustomDatatypesLowerer : public StmtExprMutator { } } - PrimExpr VisitExpr_(const LoadNode* load) final { - bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code()); - PrimExpr expr = StmtExprMutator::VisitExpr_(load); - load = expr.as(); - if (to_be_lowered) { - auto new_load_type = DataType::UInt(load->dtype.bits()); - auto buffer_var = load->buffer_var; - auto it = var_remap_.find(buffer_var); - if (it != var_remap_.end()) { - buffer_var = it->second; - } - return Load(new_load_type, buffer_var, load->index, load->predicate); - } - return expr; + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } Stmt VisitStmt_(const StoreNode* op) final { - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } - auto it = var_remap_.find(op->buffer_var); - if (it != var_remap_.end()) { - return Store(it->second, op->value, op->index, op->predicate); - } else { - return ret; + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + Buffer new_buf = GetRemappedBuffer(node->buffer); + if (!new_buf.same_as(node->buffer)) { + auto writer = node.CopyOnWrite(); + writer->buffer = new_buf; } + + return node; + } + + Buffer GetRemappedBuffer(Buffer buf) { + auto key = buf; + auto cache_it = buf_remap_.find(key); + if (cache_it != buf_remap_.end()) { + return cache_it->second; + } + + bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(buf->dtype.code()); + + if (to_be_lowered) { + auto new_load_type = DataType::UInt(buf->dtype.bits()); + auto writer = buf.CopyOnWrite(); + writer->dtype = new_load_type; + + auto var_it = var_remap_.find(buf->data); + if (var_it != var_remap_.end()) { + writer->data = var_it->second; + } + } + + buf_remap_[key] = buf; + return buf; } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -200,6 +227,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { std::string target_; // remap buffer vars std::unordered_map var_remap_; + std::unordered_map buf_remap_; }; namespace transform { diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 6f7c09cdcf2d..5d17aae1f2b8 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -109,25 +109,43 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return stmt; } } + PrimExpr VisitExpr_(const LoadNode* op) final { - auto it = load_remap_.find(op->buffer_var.get()); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto it = load_remap_.find(op->buffer.get()); if (it != load_remap_.end()) { - ICHECK(is_zero(op->index)); + for (const auto& index : op->indices) { + ICHECK(is_zero(index)); + } return it->second; } else { return StmtExprMutator::VisitExpr_(op); } } - Stmt VisitStmt_(const StoreNode* op) final { - auto it = store_remap_.find(op->buffer_var.get()); + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + + auto it = store_remap_.find(store->buffer.get()); if (it != store_remap_.end()) { - ICHECK(is_zero(op->index)); - auto value = StmtExprMutator::VisitExpr(op->value); - return Store(it->second, value, 0, op->predicate); - } else { - return StmtExprMutator::VisitStmt_(op); + for (const auto& index : op->indices) { + ICHECK(is_zero(index)); + } + + auto writer = store.CopyOnWrite(); + writer->buffer = it->second; } + + return std::move(store); } std::unordered_map new_storage_scopes_; @@ -164,11 +182,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } types[idx] = values[idx].dtype(); } - std::vector buffers(size); + std::vector buffers(size); for (size_t idx = 0; idx < size; ++idx) { - const VarNode* buffer = call->args[2 + size + idx].as(); - ICHECK(buffer); - buffers[idx] = buffer; + auto dummy_load = Downcast(call->args[2 + size + idx]); + buffers[idx] = dummy_load->buffer; } std::unordered_set reduce_set; @@ -219,8 +236,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr reduce_index = FlattenThread(vred, &reduce_extent); PrimExpr group_index = FlattenThread(vpar, &group_extent); std::vector seq; - std::vector shared_bufs(size); - std::vector local_vars; + std::vector shared_buffer_vars(size); + std::vector shared_bufs(size); + std::vector local_bufs; // // This is an optimization. For small reduction sizes, it may be beneficial // for a single warp to performance the entire reduction. No trips to shared @@ -245,19 +263,23 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // This is the index to the reduction variable, one reduction // variable per warp. Local scope seems easier to reason without // relying on a pattern match pass to fix it later. - PrimExpr index(0); + Array zero_indices = {0}; for (size_t idx = 0; idx < size; ++idx) { - Type ptr_type = PointerType(PrimType(types[idx])); - shared_bufs[idx] = Var("red_buf" + std::to_string(idx), ptr_type); + Array shape = {1}; + + Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx)); + Var buffer_var = buffer->data; + + shared_buffer_vars[idx] = buffer_var; + shared_bufs[idx] = buffer; + PrimExpr pred = const_true(types[idx].lanes()); - seq.emplace_back(Store(shared_bufs[idx], values[idx], index, pred)); + seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices)); - // Uses a local variable to store the shuffled data. - // Later on, this allocation will be properly attached to this statement. - Var var("t" + std::to_string(idx), ptr_type); - Stmt s = Allocate(var, types[idx], {PrimExpr(1)}, pred, Evaluate(0)); - local_vars.push_back(s); + // Uses a local variable to store the shuffled data. Later + // on, an allocation will be built for this local variable. + local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx))); } // The mask for this reducer, as this reducer may sit inside @@ -265,15 +287,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // active channels. // DataType mask_dtype = DataType::UInt(32); - Var mask_var("mask", PointerType(PrimType(mask_dtype))); + Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask"); { - PrimExpr pred = const_true(1); PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}); - seq.emplace_back(Store(mask_var, mask, index, pred)); - // Push allocation with an empty body. Later this will be fixed - // when the entire body is ready. - auto stmt = Allocate(mask_var, mask_dtype, {PrimExpr(1)}, pred, Evaluate(0)); - local_vars.push_back(stmt); + seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices)); + // Push the buffer description. Later this will have an + // allocation built for it. + local_bufs.push_back(mask_buffer); } // Emit reductions within a warp. @@ -281,9 +301,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Load reduction values, no synchronization needed. Array a, b; for (size_t i = 0; i < size; ++i) { - Var var = shared_bufs[i]; - PrimExpr pred = const_true(types[i].lanes()); - PrimExpr val = Load(types[i], var, index, pred); + Buffer shared_buf = shared_bufs[i]; + BufferLoad val(shared_buf, zero_indices); + ICHECK_EQ(val->dtype, types[i]); a.push_back(val); // __shfl_*sync calls shall not appear in if_then_else expressions @@ -299,12 +319,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The former may cause dead lock as there is a divergent // branch with a warp sync call inside. // - PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_var, val, offset); - const AllocateNode* repl = local_vars[i].as(); - Stmt s = Store(repl->buffer_var, other, index, pred); + PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset); + Buffer local_buf = local_bufs[i]; + Stmt s = BufferStore(local_buf, other, zero_indices); seq.push_back(s); - PrimExpr load = Load(types[i], repl->buffer_var, index, pred); + BufferLoad load = BufferLoad(local_buf, zero_indices); + ICHECK_EQ(load->dtype, types[i]); b.push_back(load); } @@ -314,9 +335,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Store the reduction result to itself. std::vector stores(size); for (size_t i = 0; i < size; ++i) { - Var var = shared_bufs[i]; - PrimExpr pred = const_true(types[i].lanes()); - stores[i] = Store(var, ret[i], index, pred); + Buffer buf = shared_bufs[i]; + stores[i] = BufferStore(buf, ret[i], zero_indices); } seq.push_back(SeqStmt::Flatten(stores)); } @@ -326,34 +346,34 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // uniformmly writting the same result. // for (size_t i = 0; i < size; ++i) { - Var var = shared_bufs[i]; - PrimExpr pred = const_true(types[i].lanes()); - PrimExpr val = Load(types[i], var, index, pred); - PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, 0); - seq.push_back(Store(var, splat, index, pred)); + Buffer buf = shared_bufs[i]; + PrimExpr val = BufferLoad(buf, zero_indices); + ICHECK_EQ(val->dtype, types[i]); + PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, 0); + seq.push_back(BufferStore(buf, splat, zero_indices)); } // Update existing allocations. for (size_t i = 0; i < size; ++i) { - ICHECK(!load_remap_.count(buffers[i])); + ICHECK(!load_remap_.count(buffers[i].get())); PrimExpr pred = const_true(types[i].lanes()); - Var var = shared_bufs[i]; - load_remap_[buffers[i]] = Load(types[i], var, index, pred); - store_remap_[buffers[i]] = var; + Buffer buf = shared_bufs[i]; + PrimExpr val = BufferLoad(buf, zero_indices); + ICHECK_EQ(val->dtype, types[i]); + load_remap_[buffers[i].get()] = val; + store_remap_[buffers[i].get()] = buf; Array extents{PrimExpr(1)}; - auto node = Allocate(var, types[i], extents, pred, Evaluate(0)); - alloc_remap_[buffers[i]] = node; + auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0)); + alloc_remap_[buffers[i]->data.get()] = node; warp_allocs_.insert(node.get()); } } else { int threadx_extent = 1; if (reduce_extent == 1) { // special case, no reduction is needed. - std::vector stores(size); + std::vector stores; for (size_t i = 0; i < size; ++i) { - PrimExpr pred = const_true(types[i].lanes()); - Var buffer_var = Downcast(call->args[2 + size + i]); - stores[i] = Store(buffer_var, values[i], 0, pred); + stores.push_back(BufferStore(buffers[i], values[i], {0})); } return SeqStmt::Flatten(stores); } @@ -365,35 +385,37 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // previous iteration on the same buffer. seq.emplace_back(SyncThread("shared")); for (size_t idx = 0; idx < size; ++idx) { - shared_bufs[idx] = Var("red_buf" + std::to_string(idx), PointerType(PrimType(types[idx]))); + Buffer buffer = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx)); + + shared_bufs[idx] = buffer; + shared_buffer_vars[idx] = buffer->data; + PrimExpr pred = const_true(types[idx].lanes()); - seq.emplace_back(Store(shared_bufs[idx], values[idx], - BufIndex(reduce_index, group_index, reduce_extent), pred)); + seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], + {BufIndex(reduce_index, group_index, reduce_extent)})); } seq.emplace_back(SyncThread("shared")); seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index, reduce_extent, threadx_extent)); for (size_t idx = 0; idx < size; ++idx) { - ICHECK(!load_remap_.count(buffers[idx])); + ICHECK(!load_remap_.count(buffers[idx].get())); PrimExpr pred = const_true(types[idx].lanes()); - load_remap_[buffers[idx]] = - Load(types[idx], shared_bufs[idx], - BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); - alloc_remap_[buffers[idx]] = - Allocate(shared_bufs[idx], types[idx], + BufferLoad load(shared_bufs[idx], + {BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)}); + ICHECK_EQ(load->dtype, types[idx]); + load_remap_[buffers[idx].get()] = load; + alloc_remap_[buffers[idx]->data.get()] = + Allocate(shared_bufs[idx]->data, types[idx], {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0)); - store_remap_[buffers[idx]] = shared_bufs[idx]; + store_remap_[buffers[idx].get()] = shared_bufs[idx]; } } // Fix all local allocations as all statements are built. Stmt body = SeqStmt::Flatten(seq); - for (auto var : local_vars) { - const AllocateNode* repl = var.as(); - if (repl) { - body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - new_storage_scopes_[repl->buffer_var.get()] = "local"; - } + for (Buffer buf : local_bufs) { + body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body); + new_storage_scopes_[buf->data.get()] = "local"; } return body; @@ -401,8 +423,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // make allreduce. Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector& types, - const Array& shared_bufs, PrimExpr reduce_index, PrimExpr group_index, - int reduce_extent, int threadx_extent) { + const Array& shared_bufs, PrimExpr reduce_index, + PrimExpr group_index, int reduce_extent, int threadx_extent) { // Get next power of two int reduce_align = 1; while (reduce_extent > reduce_align) { @@ -417,10 +439,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto fload = [&](int offset) { Array a, b; for (size_t i = 0; i < size; ++i) { - b.push_back(Load(types[i], shared_bufs[i], - BufIndex(reduce_index + offset, group_index, reduce_extent), - const_true())); - a.push_back(Load(types[i], shared_bufs[i], buf_index, const_true())); + BufferLoad b_load(shared_bufs[i], + {BufIndex(reduce_index + offset, group_index, reduce_extent)}); + ICHECK_EQ(b_load->dtype, types[i]); + b.push_back(b_load); + + BufferLoad a_load(shared_bufs[i], {buf_index}); + ICHECK_EQ(a_load->dtype, types[i]); + a.push_back(a_load); } Array ret = (*combiner)(a, b); return ret; @@ -428,7 +454,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto fstore = [&](const Array& ret) { std::vector stores(size); for (size_t i = 0; i < size; ++i) { - stores[i] = Store(shared_bufs[i], ret[i], buf_index, const_true()); + stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index}); } return SeqStmt::Flatten(stores); }; @@ -534,10 +560,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // Emit warp shuffle calls. - PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, int delta_or_lane) { - PrimExpr pred = const_true(1); - PrimExpr index(0); - PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred); + PrimExpr WarpShuffle(const Op& op, Buffer mask_buffer, PrimExpr val, int delta_or_lane) { + Array indices = {0}; + PrimExpr mask = BufferLoad(mask_buffer, indices); PrimExpr width = IntImm(DataType::Int(32), warp_size_); Array args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width}; return Call(val.dtype(), op, args); @@ -599,9 +624,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::vector thread_extents_; std::vector reduce_combiner_; // The load remap - std::unordered_map load_remap_; + std::unordered_map load_remap_; // The store remap - std::unordered_map store_remap_; + std::unordered_map store_remap_; // Allocate remap std::unordered_map alloc_remap_; // Allocate from warp reductions diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index a5ecf4ba8296..00971f7c3a98 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -34,16 +34,125 @@ namespace tvm { namespace tir { +class StackSizeChecker : public StmtExprVisitor { + public: + struct StackSizes { + // If a tvm_stack_make_shape call has no arguments, it is still + // valid and represents a scalar shape (). Therefore, -1 is used + // to represent "no shape arguments exist", while 0 represents + // "shape arguments exist, all of which are size 0". + int64_t shape_stack{-1}; + uint64_t array_stack{0}; + uint64_t arg_stack{0}; + }; + + static StackSizes Check(Stmt stmt) { + StackSizeChecker visitor; + visitor.VisitStmt(stmt); + return visitor.max_stack_; + } + + private: + void VisitStmt_(const ForNode* op) final { + if (op->kind == ForKind::kParallel) { + // Parallel for loops have their own stack and allocations, so + // stop the recursion here. + return; + } else { + this->VisitStmt(op->body); + } + } + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::tvm_call_packed())) { + return MakeCallPacked(op, /* use_string_lookup */ true); + } else if (op->op.same_as(builtin::tvm_call_cpacked())) { + return MakeCallPacked(op, /* use_string_lookup */ false); + } else if (op->op.same_as(builtin::tvm_call_trace_packed())) { + return MakeCallTracePacked(op); + } else if (op->op.same_as(builtin::tvm_stack_make_shape())) { + return MakeShape(op); + } else if (op->op.same_as(builtin::tvm_stack_make_array())) { + return MakeArray(op); + } else { + return StmtExprVisitor::VisitExpr_(op); + } + } + // call shape + void MakeShape(const CallNode* op) { + // if args.size() == 0, it is still valid and represents a scalar + // shape (). Therefore, -1 is used to represent "no shape + // arguments exist", while 0 represents "shape arguments exist, + // all of which are size 0". + if (current_stack_.shape_stack == -1) { + current_stack_.shape_stack = 0; + } + current_stack_.shape_stack += op->args.size(); + StmtExprVisitor::VisitExpr_(op); + } + // make array + void MakeArray(const CallNode* op) { + current_stack_.array_stack += 1; + StmtExprVisitor::VisitExpr_(op); + } + // call packed. + void MakeCallPacked(const CallNode* op, bool use_string_lookup) { + StackSizes restore_stack = current_stack_; + + size_t arg_count = op->args.size(); + + // cpacked expects a resource_handle parameter + if (!use_string_lookup) { + arg_count--; + } + + current_stack_.arg_stack += arg_count; + // Specially handle the buffer packed intrinsic + StmtExprVisitor::VisitExpr_(op); + // Record the amount of stack space needed, then reset the stack + // position to its previous location. + UpdateMaxStack(); + current_stack_ = restore_stack; + } + + void MakeCallTracePacked(const CallNode* op) { + StackSizes restore_stack = current_stack_; + + size_t args_size = op->args.size(); + ICHECK_GT(args_size, 0); + current_stack_.arg_stack += args_size; + + StmtExprVisitor::VisitExpr_(op); + // Record the amount of stack space needed, then reset the stack + // position to its previous location. + UpdateMaxStack(); + current_stack_ = restore_stack; + + // However, the arguments to this CallNode remain on top of the + // stack, so we can use more than one packed function's arguments + // with the one stack. + current_stack_.arg_stack = restore_stack.arg_stack + args_size - 1; + } + + void UpdateMaxStack() { + max_stack_.arg_stack = std::max(current_stack_.arg_stack, max_stack_.arg_stack); + max_stack_.shape_stack = std::max(current_stack_.shape_stack, max_stack_.shape_stack); + max_stack_.array_stack = std::max(current_stack_.array_stack, max_stack_.array_stack); + } + + StackSizes current_stack_; + StackSizes max_stack_; +}; + // Calculate the statistics of packed function. // These information are needed during codegen. class BuiltinLower : public StmtExprMutator { public: // Record stack frame for existing scope. struct AllocaScope { - Var stack_shape = Var("stack_shape", DataType::Handle()); + Buffer stack_shape; Var stack_array = Var("stack_array", DataType::Handle()); Var stack_value = Var("stack_value", DataType::Handle()); - Var stack_tcode = Var("stack_tcode", DataType::Handle()); + Buffer stack_tcode; int64_t max_shape_stack{-1}; uint64_t max_array_stack{0}; @@ -58,21 +167,41 @@ class BuiltinLower : public StmtExprMutator { // Allcoate stack frames, only at parallel-for or root. Stmt VisitBodyAndRealizeAlloca(Stmt stmt) { + // Initial check to identify maximum stack sizes. These are used + // to construct Buffer objects to hold the stack, which are then + // used when mutating. + auto max_sizes = StackSizeChecker::Check(stmt); + alloca_scope_.emplace_back(); - stmt = this->VisitStmt(stmt); - ICHECK(!alloca_scope_.empty()); auto& scope = alloca_scope_.back(); - if (scope.max_shape_stack != -1) { - stmt = LetStmt(scope.stack_shape, StackAlloca("shape", scope.max_shape_stack), stmt); + + if (max_sizes.shape_stack != -1) { + scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), max_sizes.shape_stack)}, + DataType::Int(64), "stack_shape"); + stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", max_sizes.shape_stack), stmt); } - if (scope.max_array_stack != 0) { - stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_array_stack), stmt); + if (max_sizes.arg_stack != 0) { + scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), max_sizes.arg_stack)}, + DataType::Int(32), "stack_tcode"); + stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", max_sizes.arg_stack), stmt); + + stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", max_sizes.arg_stack), stmt); } - if (scope.max_arg_stack != 0) { - stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_arg_stack), stmt); - stmt = LetStmt(scope.stack_tcode, StackAlloca("arg_tcode", scope.max_arg_stack), stmt); + + if (max_sizes.array_stack != 0) { + stmt = LetStmt(scope.stack_array, StackAlloca("array", max_sizes.array_stack), stmt); } + + // Copy these values from the earlier search, for use in bounds + // checks. + scope.max_shape_stack = max_sizes.shape_stack; + scope.max_array_stack = max_sizes.array_stack; + scope.max_arg_stack = max_sizes.arg_stack; + + stmt = this->VisitStmt(stmt); + + ICHECK(!alloca_scope_.empty()); alloca_scope_.pop_back(); return stmt; @@ -228,10 +357,10 @@ class BuiltinLower : public StmtExprMutator { op = expr.as(); // no need to perform any store for a scalar shape for (size_t i = 0; i < op->args.size(); ++i) { - prep_seq.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]), - ConstInt32(stack_begin + i), const_true(1))); + prep_seq.emplace_back(BufferStore(scope.stack_shape, cast(DataType::Int(64), op->args[i]), + {ConstInt32(stack_begin + i)})); } - return AddressOffset(scope.stack_shape, DataType::Int(64), stack_begin); + return AddressOffset(scope.stack_shape->data, DataType::Int(64), stack_begin); } // make array PrimExpr MakeArray(const CallNode* op) { @@ -312,17 +441,16 @@ class BuiltinLower : public StmtExprMutator { arg_tcode = kTVMStr; } if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle; - prep_seq.emplace_back( - Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1))); + prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); } - // UPDATE stack value - scope.max_arg_stack = std::max(scope.run_arg_stack, scope.max_arg_stack); - scope.max_shape_stack = std::max(scope.run_shape_stack, scope.max_shape_stack); - scope.max_array_stack = std::max(scope.run_array_stack, scope.max_array_stack); + // Verify stack size matches earlier value. + ICHECK_LE(scope.run_arg_stack, scope.max_arg_stack); + ICHECK_LE(scope.run_shape_stack, scope.max_shape_stack); + ICHECK_LE(scope.run_array_stack, scope.max_array_stack); scope.run_shape_stack = restore_shape_stack; scope.run_array_stack = restore_array_stack; scope.run_arg_stack = arg_stack_begin; - Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode, + Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1)}; @@ -363,19 +491,18 @@ class BuiltinLower : public StmtExprMutator { builtin::kTVMValueContent, arg)); int arg_tcode = api_type.code(); ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; - prep_seq.emplace_back( - Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1))); + prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); } - // UPDATE stack value - scope.max_arg_stack = std::max(scope.run_arg_stack, scope.max_arg_stack); - scope.max_shape_stack = std::max(scope.run_shape_stack, scope.max_shape_stack); - scope.max_array_stack = std::max(scope.run_array_stack, scope.max_array_stack); + // Verify stack size matches earlier value. + ICHECK_LE(scope.run_arg_stack, scope.max_arg_stack); + ICHECK_LE(scope.run_shape_stack, scope.max_shape_stack); + ICHECK_LE(scope.run_array_stack, scope.max_array_stack); scope.run_shape_stack = restore_shape_stack; scope.run_array_stack = restore_array_stack; // Update the top of the stack, so we can use more than one // packed function's arguments with the one stack. scope.run_arg_stack = arg_stack_begin + args_size - 1; - Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode, + Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1), // Pass traced value. diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 48313295d5ef..8ba1b7d7e374 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -114,19 +114,31 @@ class WarpStoreCoeffFinder : private StmtVisitor { private: /// Visitor implementation void VisitStmt_(const StoreNode* op) final { - if (op->buffer_var.get() == buffer_) { - if (op->value.dtype().lanes() == 1) { - UpdatePattern(op->index); - } else { - arith::PVar base; - ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index)) - << "LowerWarpMemory failed due to store index=" << op->index - << ", can only handle continuous store"; - UpdatePattern(base.Eval()); - } - } else { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitStmt_(const BufferStoreNode* op) final { + if (op->buffer->data.get() != buffer_) { StmtVisitor::VisitStmt_(op); + return; + } + + ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " + << "Has StorageFlatten (TE-based schedule) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + + PrimExpr index = op->indices[0]; + if (op->value.dtype().lanes() != 1) { + arith::PVar base; + ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(index)) + << "LowerWarpMemory failed due to store index=" << index + << ", can only handle continuous store"; + UpdatePattern(base.Eval()); + + index = base.Eval(); } + + UpdatePattern(index); } void UpdatePattern(const PrimExpr& index) { @@ -239,35 +251,62 @@ class WarpAccessRewriter : protected StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) override { - if (op->buffer_var.get() == buffer_) { - PrimExpr local_index, group; - std::tie(local_index, group) = SplitIndexByGroup(op->index); - PrimExpr new_value = VisitExpr(op->value); - return Store(op->buffer_var, new_value, local_index, op->predicate); - } else { - return StmtExprMutator::VisitStmt_(op); - } + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); } PrimExpr VisitExpr_(const LoadNode* op) override { - if (op->buffer_var.get() == buffer_) { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) override { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + + if (store->buffer->data.get() == buffer_) { + ICHECK_EQ(store->indices.size(), 1) << "Expected flat memory to use as warp memory. " + << "Has StorageFlatten (TE-based schedule) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + PrimExpr local_index, group; - std::tie(local_index, group) = SplitIndexByGroup(op->index); - // invariance: local index must do not contain warp id - ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) - << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index - << " local_index=" << local_index; - PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); - if (analyzer_->CanProveEqual(group, warp_index_)) { - return load_value; - } - PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}); - return Call(load_value.dtype(), builtin::tvm_warp_shuffle(), - {mask, load_value, group, width_, warp_size_}); - } else { - return StmtExprMutator::VisitExpr_(op); + std::tie(local_index, group) = SplitIndexByGroup(store->indices[0]); + + auto writer = store.CopyOnWrite(); + writer->indices = {local_index}; + } + + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) override { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + + if (load->buffer->data.get() != buffer_) { + return std::move(load); + } + + ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " + << "Has StorageFlatten (TE-based schedule) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + + PrimExpr local_index, group; + std::tie(local_index, group) = SplitIndexByGroup(op->indices[0]); + // invariance: local index must do not contain warp id + ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) + << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->indices[0] + << " local_index=" << local_index; + + auto writer = load.CopyOnWrite(); + writer->indices = {local_index}; + + if (analyzer_->CanProveEqual(group, warp_index_)) { + return std::move(load); } + + PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}); + return Call(load.dtype(), builtin::tvm_warp_shuffle(), {mask, load, group, width_, warp_size_}); } + // Split the index to the two component // // local index is the index in the local diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index f3ff1f37a5da..2917ef218f0b 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -102,12 +102,17 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { alloc_info_[buf].level = level; StmtExprVisitor::VisitStmt_(op); } + void VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + } + + void VisitStmt_(const BufferStoreNode* op) final { scope_.push_back(StmtEntry()); // visit subexpr StmtExprVisitor::VisitStmt_(op); // Add write access. - const VarNode* buf = op->buffer_var.get(); + const VarNode* buf = op->buffer->data.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()); @@ -122,6 +127,7 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { linear_seq_.push_back(e); } } + void VisitStmt_(const EvaluateNode* op) final { scope_.push_back(StmtEntry()); // visit subexpr @@ -133,10 +139,15 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { linear_seq_.push_back(e); } } + void VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + } + + void VisitExpr_(const BufferLoadNode* op) final { // Add write access. StmtExprVisitor::VisitExpr_(op); - const VarNode* buf = op->buffer_var.get(); + const VarNode* buf = op->buffer->data.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; @@ -145,6 +156,7 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { } } } + void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); @@ -294,22 +306,61 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { - if (IsDynamicSharedMemory(op->buffer_var)) { - PrimExpr offset = GetBufferOffset(op->buffer_var, op->dtype); - PrimExpr index = StmtExprMutator::VisitExpr(op->index); - return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span); - } - return StmtExprMutator::VisitExpr_(op); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } Stmt VisitStmt_(const StoreNode* op) final { - if (IsDynamicSharedMemory(op->buffer_var)) { - PrimExpr offset = GetBufferOffset(op->buffer_var, op->value->dtype); - PrimExpr index = StmtExprMutator::VisitExpr(op->index); - PrimExpr value = StmtExprMutator::VisitExpr(op->value); - return Store(merged_buf_var_, value, offset + index, op->predicate, op->span); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + if (IsDynamicSharedMemory(node->buffer->data)) { + ICHECK_EQ(node->indices.size(), 1) + << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, " + << "and is to be run after " + << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)"; + Array indices = {node->indices[0] + + this->GetBufferOffset(node->buffer->data, node->buffer->dtype)}; + + auto writer = node.CopyOnWrite(); + writer->buffer = GetUpdatedBuffer(node->buffer); + writer->indices = indices; } - return StmtExprMutator::VisitStmt_(op); + + return node; + } + + Buffer GetUpdatedBuffer(Buffer buffer) { + auto key = buffer.get(); + auto it = buffer_remap_.find(key); + if (it != buffer_remap_.end()) { + return it->second; + } + + if (IsDynamicSharedMemory(buffer->data)) { + ICHECK_EQ(buffer->shape.size(), 1) + << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, " + << "and is to be run after " + << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)"; + auto writer = buffer.CopyOnWrite(); + writer->data = merged_buf_var_; + } + + buffer_remap_[key] = buffer; + return buffer; } PrimExpr VisitExpr_(const CallNode* op) final { @@ -542,6 +593,8 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { PrimExpr merged_alloc_size_{0}; // The mapping from the original buffer var to its offset in the merged buffer std::unordered_map buffer_byte_offsets_; + // The mapping from the original buffer objects to their location in the merged buffer. + std::unordered_map buffer_remap_; // The flag indicating whether the merged buffer has been allocated bool allocated_{false}; // Locations of free ops. diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index f1286d773c2d..1ce54846aaea 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -58,10 +58,14 @@ class UnsafeExprDetector : public ExprFunctor { return true; } } - bool VisitExpr_(const LoadNode* op) { + bool VisitExpr_(const BufferLoadNode* op) { // Load is considered unsafe. return true; } + bool VisitExpr_(const LoadNode* op) { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return true; + } bool VisitExpr_(const AddNode* op) final { return BinaryOp(op); } bool VisitExpr_(const SubNode* op) final { return BinaryOp(op); } bool VisitExpr_(const MulNode* op) final { return BinaryOp(op); } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index cd2d230f5775..3fed1e193de9 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -62,30 +62,89 @@ class VecAllocAccess : public StmtExprMutator { public: VecAllocAccess(const VarNode* buf, Var var, int var_lanes) : buf_(buf), var_(var), var_lanes_(var_lanes) {} - // Load + PrimExpr VisitExpr_(const LoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (op->buffer_var.get() == buf_) { - return Load(op->dtype, op->buffer_var, op->index * var_lanes_ + var_, op->predicate); - } else { - return expr; - } + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); } - // Store + Stmt VisitStmt_(const StoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (op->buffer_var.get() == buf_) { - return Store(op->buffer_var, op->value, op->index * var_lanes_ + var_, op->predicate); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + return UpdateBufferAccess(load); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + return UpdateBufferAccess(store); + } + + private: + template + Node UpdateBufferAccess(Node node) { + // Only update the buffer that's being replaced. + if (node->buffer->data.get() != buf_) { + return node; + } + + arith::Analyzer analyzer; + + // Find/make a Buffer object with the correct updated shape. + Buffer buf; + auto it = buffer_map_.find(node->buffer.get()); + if (it != buffer_map_.end()) { + buf = it->second; } else { - return stmt; + // Extend the least significant dimension by a factor of + // var_lanes_. Typically, this will be a 1-d index into a flat + // memory space. + Array shape = node->buffer->shape; + shape.Set(shape.size() - 1, analyzer.Simplify(shape[shape.size() - 1] * var_lanes_)); + + // TODO(Lunderberg): Move this pass to be prior to + // StorageFlatten/FlattenBuffer, implement by appending a + // dimension to the buffer. Since it is currently after the + // flattening, the strides are not technically necessary, but + // are updated for consistency. + + // Update strides if defined. + Array strides; + for (size_t i = 0; i < strides.size(); i++) { + PrimExpr stride = strides[i]; + if (i != strides.size() - 1) { + stride *= var_lanes_; + } + strides.push_back(analyzer.Simplify(stride)); + } + + // Copy everything into the new buffer. + buf = node->buffer; + auto buf_writer = buf.CopyOnWrite(); + buf_writer->shape = shape; + buf_writer->strides = strides; + buffer_map_[buf.get()] = buf; } + + // Extend the last index by the number of lanes in the vectorized + // variable. + Array indices = node->indices; + indices.Set(indices.size() - 1, + analyzer.Simplify(indices[indices.size() - 1] * var_lanes_ + var_)); + + auto writer = node.CopyOnWrite(); + writer->buffer = buf; + writer->indices = indices; + return node; } - private: // buffer var const VarNode* buf_; + // Updated buffer objects. + std::unordered_map buffer_map_; // variable to be replaced Var var_; // the lanes. @@ -312,15 +371,24 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->index); - PrimExpr pred = this->VisitExpr(op->predicate); - if (index.same_as(op->index) && pred.same_as(op->predicate)) { - return GetRef(op); - } else { - int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes()); - return Load(op->dtype.with_lanes(lanes), op->buffer_var, BroadcastTo(index, lanes), - BroadcastTo(pred, lanes)); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + // BufferLoad + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto load = GetRef(op); + + auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; + Array indices = op->indices; + indices.MutateByApply(fmutate); + + if (!indices.same_as(op->indices)) { + auto writer = load.CopyOnWrite(); + writer->indices = indices; + writer->LegalizeDtype(); } + + return std::move(load); } // Let PrimExpr VisitExpr_(const LetNode* op) final { @@ -352,17 +420,43 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); + + auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; + Array indices = op->indices; + indices.MutateByApply(fmutate); + PrimExpr value = this->VisitExpr(op->value); - PrimExpr index = this->VisitExpr(op->index); - PrimExpr pred = this->VisitExpr(op->predicate); - if (value.same_as(op->value) && index.same_as(op->index)) { - return GetRef(op); - } else { - int lanes = std::max(value.dtype().lanes(), index.dtype().lanes()); - lanes = std::max(lanes, pred.dtype().lanes()); - return Store(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes), - BroadcastTo(pred, lanes)); + + if (!indices.same_as(op->indices) || !value.same_as(op->value)) { + int index_lanes = 1; + for (const auto& index : indices) { + index_lanes *= index.dtype().lanes(); + } + + int lanes = std::max(index_lanes, value.dtype().lanes()); + + int last_index_lanes = indices[indices.size() - 1].dtype().lanes(); + int earlier_index_lanes = index_lanes / last_index_lanes; + + // Broadcast the last index such that the total number of index + // lanes matches the desired number. + ICHECK_EQ(lanes % last_index_lanes, 0) + << "Cannot produce location with " << value.dtype().lanes(); + indices.Set(indices.size() - 1, + BroadcastTo(indices[indices.size() - 1], lanes / earlier_index_lanes)); + + auto writer = store.CopyOnWrite(); + writer->indices = indices; + writer->value = BroadcastTo(value, lanes); } + + return std::move(store); } // For Stmt VisitStmt_(const ForNode* op) final { @@ -429,23 +523,35 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->condition); if (condition.dtype().is_vector()) { - LOG(WARNING) << "Cannot handle vector extent in alloc "; + LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint; return Scalarize(GetRef(op)); } + + // Mutate the extents Array extents; - for (size_t i = 0; i < op->extents.size(); i++) { - PrimExpr new_ext = this->VisitExpr(op->extents[i]); + for (const auto& extent : op->extents) { + PrimExpr new_ext = this->VisitExpr(extent); if (new_ext.dtype().is_vector()) { - LOG(WARNING) << "Cannot handle vector extent in alloc "; + LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint; return Scalarize(GetRef(op)); } extents.push_back(new_ext); } - // place the vector lanes in least significant dimension. - extents.push_back(var_lanes_); - // rewrite access to buffer internally. + + // TODO(Lunderberg): Move this pass to be prior to + // StorageFlatten/FlattenBuffer. That will allow this pass to be + // implemented as adding a new buffer dimension, which is later + // flattened. + + // Extend the least significant dimension by a factor of + // var_lanes_. Typically, this will be a 1-d index into a flat + // memory space. + extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_); + + // Rewrite access to the buffer in the body. Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); body = this->VisitStmt(body); return Allocate(op->buffer_var, op->dtype, extents, condition, body); From cbc3a6bbd6669f2882f1da6a7b0f4ba8dc414c17 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 10 Nov 2021 11:59:25 -0600 Subject: [PATCH 009/177] Replacing Load/Store in codegens. - Device code generators - CodegenC - CodegenLLVM - CodeGenOpenCL - Utilities used during codegen - ArgBinder - MakePackedAPI - ReturnRewriter - SplitHostDevice - Execution environments - CodeGenStackVM - CodeGenHybrid - AOTExecutorCodegen --- src/contrib/hybrid/codegen_hybrid.cc | 8 +++ src/contrib/hybrid/codegen_hybrid.h | 2 + src/relay/backend/aot_executor_codegen.cc | 12 ++-- src/target/llvm/codegen_llvm.cc | 60 +++++++++++++------- src/target/llvm/codegen_llvm.h | 2 + src/target/source/codegen_c.cc | 59 ++++++++++++-------- src/target/source/codegen_c.h | 5 +- src/target/source/codegen_cuda.cc | 4 +- src/target/source/codegen_cuda.h | 3 +- src/target/source/codegen_opencl.cc | 6 +- src/target/source/codegen_opencl.h | 1 + src/target/spirv/codegen_spirv.cc | 40 ++++++++------ src/target/spirv/codegen_spirv.h | 4 +- src/target/stackvm/codegen_stackvm.cc | 34 +++++++++--- src/target/stackvm/codegen_stackvm.h | 2 + src/tir/transforms/arg_binder.cc | 34 ++++++------ src/tir/transforms/make_packed_api.cc | 67 +++++++++++++++++------ src/tir/transforms/split_host_device.cc | 32 ++++++++++- 18 files changed, 256 insertions(+), 119 deletions(-) diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 54edbaee35cd..1c8cfd01d4c7 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -271,6 +271,14 @@ void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLIN void CodeGenHybrid::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Phase 0 has no Store(s)!"; } +void CodeGenHybrid::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) + LOG(FATAL) << "Phase 0 has no BufferLoad(s)!"; +} + +void CodeGenHybrid::VisitStmt_(const BufferStoreNode* op) { + LOG(FATAL) << "Phase 0 has no BufferStore(s)!"; +} + void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Phase 0 has no Let(s)!"; } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 47c13f73022f..da45ffb6a8ce 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -89,6 +89,7 @@ class CodeGenHybrid : public ExprFunctor, // expression void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const ProducerLoadNode* op, std::ostream& os) override; // NOLINT(*) @@ -120,6 +121,7 @@ class CodeGenHybrid : public ExprFunctor, // statment void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const StoreNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ProducerStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index f076efeb4ac5..dd80e54553d5 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -403,15 +403,17 @@ class AOTExecutorCodegen : public MixedModeVisitor { */ void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) { // Define intermediate DLTensor to load/store the data - auto tmp0 = te::Var("tmp0", DataType::Handle()); - auto tmp1 = te::Var("tmp1", DataType::Handle()); + tir::Buffer tmp_read = + tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read"); + tir::Buffer tmp_write = + tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write"); te::Var loop_idx("i", DataType::Int(32)); - auto retval_i = tir::Load(DataType::UInt(8), tmp0, loop_idx, tir::const_true()); + auto retval_i = tir::BufferLoad(tmp_read, {loop_idx}); // Copy the variable from the input to the output tir::Stmt copy = tir::For(loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial, - tir::Store(tmp1, tir::Let(tmp0, in, retval_i), loop_idx, tir::const_true())); - stmts_.push_back(tir::LetStmt(tmp1, out, copy)); + tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); + stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy)); } /* diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index cc6fdc31c563..4cf89df2494e 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1236,14 +1236,24 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { + LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead."; + return NULL; +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; + DataType t = op->dtype; - bool is_volatile = volatile_buf_.count(op->buffer_var.get()); - llvm::Value* buffer = MakeValue(op->buffer_var); - llvm::Value* index = MakeValue(op->index); + Var buffer_var = op->buffer->data; + const PrimExpr& buffer_index = op->indices[0]; + + bool is_volatile = volatile_buf_.count(buffer_var.get()); + llvm::Value* buffer = MakeValue(buffer_var); + llvm::Value* index = MakeValue(buffer_index); if (t.lanes() == 1) { int alignment, native_bits; - GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); + GetAlignment(t, buffer_var.get(), buffer_index, &alignment, &native_bits); TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, @@ -1254,14 +1264,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { #else llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(load, op->buffer_var.get(), op->index); + AddAliasInfo(load, buffer_var.get(), buffer_index); return load; } else { // vector load - if (const RampNode* ramp = op->index.as()) { + if (const RampNode* ramp = buffer_index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; - GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); + GetAlignment(t, buffer_var.get(), ramp->base, &alignment, &native_bits); ICHECK_EQ(ramp->lanes, t.lanes()); // The index argument is element-based, to create buffer pointer for t's element type. TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); @@ -1279,7 +1289,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { #else llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(load, op->buffer_var.get(), op->index); + AddAliasInfo(load, buffer_var.get(), buffer_index); return load; } } @@ -1299,9 +1309,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, basic_align, is_volatile); #endif ret = builder_->CreateInsertElement(ret, load, ConstInt32(i)); - AddAliasInfo(load, op->buffer_var.get(), PrimExpr()); + AddAliasInfo(load, buffer_var.get(), PrimExpr()); }; - this->Scalarize(op->index, f); + this->Scalarize(buffer_index, f); return ret; } @@ -1366,16 +1376,24 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { } void CodeGenLLVM::VisitStmt_(const StoreNode* op) { - ICHECK(is_one(op->predicate)) << op->predicate; + LOG(FATAL) << "Unexpected deprecated StoreNode. Use BufferStoreNode instead."; +} + +void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; + DataType t = op->value.dtype(); - bool is_volatile = volatile_buf_.count(op->buffer_var.get()); - llvm::Value* buffer = MakeValue(op->buffer_var); - llvm::Value* index = MakeValue(op->index); + Var buffer_var = op->buffer->data; + PrimExpr buffer_index = op->indices[0]; + + bool is_volatile = volatile_buf_.count(buffer_var.get()); + llvm::Value* buffer = MakeValue(buffer_var); + llvm::Value* index = MakeValue(buffer_index); llvm::Value* value = MakeValue(op->value); if (t.lanes() == 1) { int alignment, native_bits; - GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); + GetAlignment(t, buffer_var.get(), buffer_index, &alignment, &native_bits); TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = @@ -1384,14 +1402,14 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(store, op->buffer_var.get(), op->index); + AddAliasInfo(store, buffer_var.get(), buffer_index); return; } else { // vector store - if (const RampNode* ramp = op->index.as()) { + if (const RampNode* ramp = buffer_index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; - GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); + GetAlignment(t, buffer_var.get(), ramp->base, &alignment, &native_bits); ICHECK_EQ(ramp->lanes, t.lanes()); // The index argument is element-based, to create buffer pointer for t's element type. TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); @@ -1407,7 +1425,7 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(store, op->buffer_var.get(), op->index); + AddAliasInfo(store, buffer_var.get(), buffer_index); return; } } @@ -1425,9 +1443,9 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { llvm::StoreInst* store = builder_->CreateAlignedStore( builder_->CreateExtractElement(value, i), buffer_ptr.addr, basic_align, is_volatile); #endif - AddAliasInfo(store, op->buffer_var.get(), PrimExpr()); + AddAliasInfo(store, buffer_var.get(), PrimExpr()); }; - this->Scalarize(op->index, f); + this->Scalarize(buffer_index, f); } void CodeGenLLVM::VisitStmt_(const ForNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index a40677c955f8..17a509ece005 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -171,12 +171,14 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* VisitExpr_(const SelectNode* op) override; llvm::Value* VisitExpr_(const LetNode* op) override; llvm::Value* VisitExpr_(const LoadNode* op) override; + llvm::Value* VisitExpr_(const BufferLoadNode* op) override; llvm::Value* VisitExpr_(const CallNode* op) override; llvm::Value* VisitExpr_(const RampNode* op) override; llvm::Value* VisitExpr_(const ShuffleNode* op) override; llvm::Value* VisitExpr_(const BroadcastNode* op) override; // stmt void VisitStmt_(const StoreNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index e6f81646242d..214009368cda 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -649,18 +649,25 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, } void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) + LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead."; +} + +void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) + ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; + + PrimExpr index = op->indices[0]; + Var buffer_var = op->buffer->data; + int lanes = op->dtype.lanes(); // delcare type. if (op->dtype.lanes() == 1) { - std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index); + std::string ref = GetBufferRef(op->dtype, buffer_var.get(), index); HandleVolatileLoads(ref, op, os); } else { - ICHECK(is_one(op->predicate)) << "predicated load is not supported"; - bool can_vector_load = false; arith::PVar base; - if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) { - const RampNode* ramp = op->index.as(); + if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { + const RampNode* ramp = index.as(); ICHECK(ramp); arith::ModularSet me = arith::Analyzer().modular_set(ramp->base); // The condition: {k * coeff + base} divisible by the alignment for any k @@ -670,19 +677,19 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) } if (can_vector_load) { - std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base.Eval()); + std::string ref = GetVecLoad(op->dtype, buffer_var.get(), base.Eval()); HandleVolatileLoads(ref, op, os); } else { std::ostringstream svalue_expr; - std::string sindex = SSAGetID(PrintExpr(op->index), op->index.dtype()); - std::string vid = GetVarID(op->buffer_var.get()); + std::string sindex = SSAGetID(PrintExpr(index), index.dtype()); + std::string vid = GetVarID(buffer_var.get()); DataType elem_type = op->dtype.element_of(); for (int i = 0; i < lanes; ++i) { std::ostringstream value_temp; - if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) { + if (!HandleTypeMatch(buffer_var.get(), elem_type)) { value_temp << "(("; - if (op->buffer_var.get()->dtype.is_handle()) { - auto it = alloc_storage_scope_.find(op->buffer_var.get()); + if (buffer_var.get()->dtype.is_handle()) { + auto it = alloc_storage_scope_.find(buffer_var.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, value_temp); } @@ -693,7 +700,7 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) value_temp << vid; } value_temp << '['; - PrintVecElemLoad(sindex, op->index.dtype(), i, value_temp); + PrintVecElemLoad(sindex, index.dtype(), i, value_temp); value_temp << ']'; PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr); } @@ -703,35 +710,43 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) } void CodeGenC::VisitStmt_(const StoreNode* op) { + LOG(FATAL) << "Unexpected deprecated StoreNode. Use BufferStoreNode instead."; +} + +void CodeGenC::VisitStmt_(const BufferStoreNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + DataType t = op->value.dtype(); + PrimExpr index_expr = op->indices[0]; + Var buffer_var = op->buffer->data; + if (t.lanes() == 1) { std::string value = this->PrintExpr(op->value); - std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); + std::string ref = this->GetBufferRef(t, buffer_var.get(), index_expr); this->PrintIndent(); stream << ref << " = " << value << ";\n"; } else { - ICHECK(is_one(op->predicate)) << "Predicated store is not supported"; arith::PVar base; - if (arith::ramp(base, 1, t.lanes()).Match(op->index)) { + if (arith::ramp(base, 1, t.lanes()).Match(index_expr)) { std::string value = this->PrintExpr(op->value); - this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value); + this->PrintVecStore(buffer_var.get(), t, base.Eval(), value); } else { // The assignment below introduces side-effect, and the resulting value cannot // be reused across multiple expression, thus a new scope is needed int vec_scope = BeginScope(); // store elements seperately - std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype()); + std::string index = SSAGetID(PrintExpr(index_expr), index_expr.dtype()); std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype()); - std::string vid = GetVarID(op->buffer_var.get()); + std::string vid = GetVarID(buffer_var.get()); for (int i = 0; i < t.lanes(); ++i) { this->PrintIndent(); DataType elem_type = t.element_of(); - if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) { + if (!HandleTypeMatch(buffer_var.get(), elem_type)) { stream << "(("; - if (op->buffer_var.get()->dtype.is_handle()) { - auto it = alloc_storage_scope_.find(op->buffer_var.get()); + if (buffer_var.get()->dtype.is_handle()) { + auto it = alloc_storage_scope_.find(buffer_var.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); } @@ -742,7 +757,7 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { stream << vid; } stream << '['; - PrintVecElemLoad(index, op->index.dtype(), i, stream); + PrintVecElemLoad(index, index_expr.dtype(), i, stream); stream << "] = "; PrintVecElemLoad(value, op->value.dtype(), i, stream); stream << ";\n"; diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 3b042b9fbd2c..3536b74b5636 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -126,6 +126,7 @@ class CodeGenC : public ExprFunctor, // expression void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) @@ -155,6 +156,7 @@ class CodeGenC : public ExprFunctor, // statment void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const StoreNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; @@ -206,7 +208,8 @@ class CodeGenC : public ExprFunctor, * does not implement volatile member functions. CUDA codegen will cast * away volatile qualifier from CUDA __half types. */ - virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) { + virtual void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, + std::ostream& os) { // By default, do nothing but print the loaded value. os << value; } diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index a1f257391db4..28d972232f5f 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1106,12 +1106,12 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode return 0; } -void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode* op, +void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, std::ostream& os) { // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. // - if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer_var.get())) { + if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) { os << "("; PrintType(op->dtype, os); os << ")(" << value << ")"; diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 385b7343c8fd..673753c470ae 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -76,7 +76,8 @@ class CodeGenCUDA final : public CodeGenC { private: // Handle volatile loads - void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) final; + void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, + std::ostream& os) final; // Whether scope such as "__shared__" or "__constant__" is part of type. bool IsScopePartOfType() const final { return false; } diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 507a6243cb0c..6e39306be11e 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -337,13 +337,17 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType } void CodeGenOpenCL::VisitStmt_(const StoreNode* op) { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; +} + +void CodeGenOpenCL::VisitStmt_(const BufferStoreNode* op) { if (auto call = op->value.as()) { if (call->op.same_as(builtin::texture2d_load())) { need_texture_ssa_ = false; // If storing a texture load into a buffer, don't use an // intermediate local unless the buffer allocation is a // single element selected from the texture read. - auto it = allocation_size_.find(op->buffer_var.get()); + auto it = allocation_size_.find(op->buffer->data.get()); if (it != allocation_size_.end() && it->second == 1) { need_texture_ssa_ = true; } diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 8c36a817753c..c72875e8561f 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -64,6 +64,7 @@ class CodeGenOpenCL final : public CodeGenC { void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const StoreNode* op) final; // NOLINT(*) + void VisitStmt_(const BufferStoreNode* op) final; // NOLINT(*) // overload min and max to avoid ambiguous call errors void VisitExpr_(const MinNode* op, std::ostream& os) final; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 66952dae269e..bbe5de1fa58d 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -412,22 +412,23 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { return builder_->Concat(values); } -spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { - ICHECK(is_one(op->predicate)); +spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; + Var buffer_var = op->buffer->data; + PrimExpr prim_index = op->indices[0]; DataType desired_read_type = op->dtype; if (desired_read_type == DataType::Bool()) { desired_read_type = boolean_storage_type_.with_lanes(desired_read_type.lanes()); } - const VarNode* buffer_var = op->buffer_var.get(); - auto it = storage_info_.find(buffer_var); + auto it = storage_info_.find(buffer_var.get()); ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; - info.CheckContentType(desired_read_type, op->index.dtype().lanes()); + info.CheckContentType(desired_read_type, prim_index.dtype().lanes()); spirv::SType content_type = builder_->GetSType(info.element_type); - spirv::Value buffer = MakeValue(op->buffer_var); + spirv::Value buffer = MakeValue(buffer_var); spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); uint32_t mask = spv::MemoryAccessMaskNone; @@ -438,7 +439,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { if (desired_read_type == info.element_type) { // Requested a single value from an array. This may be a scalar load // or a vectorized load, based on the array element type. - spirv::Value index = MakeValue(op->index); + spirv::Value index = MakeValue(prim_index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); // OpTypeBool have no physical address/storage. Here, cast from @@ -457,13 +458,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); }; - this->Scalarize(op->index, f); + this->Scalarize(prim_index, f); return builder_->Concat(values); } else { LOG(FATAL) << "Cannot perform buffer access of buffer variable '" << buffer_var->name_hint << "' with element type " << info.element_type << " using index of type " - << op->index->dtype << " to produce output of type " << op->dtype; + << prim_index->dtype << " to produce output of type " << op->dtype; return spirv::Value(); } } @@ -483,15 +484,18 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::functionpredicate)); - auto it = storage_info_.find(op->buffer_var.get()); +void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; + Var buffer_var = op->buffer->data; + PrimExpr prim_index = op->indices[0]; + + auto it = storage_info_.find(buffer_var.get()); ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; - info.CheckContentType(op->value.dtype(), op->index.dtype().lanes()); + info.CheckContentType(op->value.dtype(), prim_index.dtype().lanes()); spirv::SType content_type = builder_->GetSType(info.element_type); - spirv::Value buffer = MakeValue(op->buffer_var); + spirv::Value buffer = MakeValue(buffer_var); spirv::Value value = MakeValue(op->value); spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); @@ -505,7 +509,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { // or a vectorized store, based on the array element type. ICHECK_EQ(info.element_type, op->value.dtype()) << "Vulkan only allow one type access to the same buffer"; - spirv::Value index = MakeValue(op->index); + spirv::Value index = MakeValue(prim_index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, value, mask); @@ -517,12 +521,12 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, elem, mask); }; - this->Scalarize(op->index, f); + this->Scalarize(prim_index, f); } else { LOG(FATAL) << "Cannot store value of type " << op->value.dtype() << " into buffer variable '" - << op->buffer_var->name_hint << "' with element type " << info.element_type - << " using index of type " << op->index->dtype; + << buffer_var->name_hint << "' with element type " << info.element_type + << " using index of type " << prim_index->dtype; } } diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 74b62e7613d1..08b9db0ee539 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -100,9 +100,9 @@ class CodeGenSPIRV : public ExprFunctor, spirv::Value VisitExpr_(const CallNode* op) override; spirv::Value VisitExpr_(const RampNode* op) override; spirv::Value VisitExpr_(const BroadcastNode* op) override; - spirv::Value VisitExpr_(const LoadNode* op) override; + spirv::Value VisitExpr_(const BufferLoadNode* op) override; // stmt - void VisitStmt_(const StoreNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 402e3291975f..e93b01becabe 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -140,12 +140,21 @@ int CodeGenStackVM::GetVarID(const VarNode* v) const { } void CodeGenStackVM::VisitExpr_(const LoadNode* op) { - this->Push(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; +} + +void CodeGenStackVM::VisitExpr_(const BufferLoadNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + auto index = op->indices[0]; + + this->Push(op->buffer->data); StackVM::OpCode code = StackVM::GetLoad(op->dtype); - if (const IntImmNode* index = op->index.as()) { - this->PushOp(code, index->value); + if (const IntImmNode* int_index = index.as()) { + this->PushOp(code, int_index->value); } else { - this->Push(op->index); + this->Push(index); this->PushOp(StackVM::PUSH_I64, op->dtype.element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); @@ -154,13 +163,22 @@ void CodeGenStackVM::VisitExpr_(const LoadNode* op) { } void CodeGenStackVM::VisitStmt_(const StoreNode* op) { - this->Push(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; +} + +void CodeGenStackVM::VisitStmt_(const BufferStoreNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. " + << "Has StorageFlatten (TE-based schedules) or " + << "FlattenBuffer (TIR-based schedules) been run?"; + auto index = op->indices[0]; + + this->Push(op->buffer->data); StackVM::OpCode code = StackVM::GetStore(op->value.dtype()); - if (const IntImmNode* index = op->index.as()) { + if (const IntImmNode* int_index = index.as()) { this->Push(op->value); - this->PushOp(code, index->value); + this->PushOp(code, int_index->value); } else { - this->Push(op->index); + this->Push(index); this->PushOp(StackVM::PUSH_I64, op->value.dtype().element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index 480ffc7eb870..ae6f316b475d 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -108,6 +108,7 @@ class CodeGenStackVM : public ExprFunctor, // expression void VisitExpr_(const VarNode* op) final; void VisitExpr_(const LoadNode* op) final; + void VisitExpr_(const BufferLoadNode* op) final; void VisitExpr_(const LetNode* op) final; void VisitExpr_(const CallNode* op) final; void VisitExpr_(const AddNode* op) final; @@ -136,6 +137,7 @@ class CodeGenStackVM : public ExprFunctor, // statment void VisitStmt_(const LetStmtNode* op) final; void VisitStmt_(const StoreNode* op) final; + void VisitStmt_(const BufferStoreNode* op) final; void VisitStmt_(const ForNode* op) final; void VisitStmt_(const IfThenElseNode* op) final; void VisitStmt_(const AllocateNode* op) final; diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index d3ab32cbd7f9..013297c2550c 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -184,10 +184,12 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, IntImm(DataType::Int(32), buffer->data_alignment), nop)); } + Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())}, tvm_shape_type, + arg_name + ".shape"); Var v_shape(arg_name + ".shape", DataType::Handle()); def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); init_nest_.emplace_back( - LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); + LetStmt(buf_shape->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); for (size_t k = 0; k < buffer->shape.size(); ++k) { if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { break; @@ -195,16 +197,16 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, std::ostringstream field_name; field_name << v_shape->name_hint << '[' << k << ']'; Bind_(buffer->shape[k], - cast(buffer->shape[k].dtype(), - Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))), + cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})), field_name.str(), true); } // strides field - Var v_strides(arg_name + ".strides", DataType::Handle()); - def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); - init_nest_.emplace_back( - LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); - PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}); + Buffer buf_strides = decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())}, + tvm_shape_type, arg_name + ".strides"); + def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type)); + init_nest_.emplace_back(LetStmt( + buf_strides->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); + PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data}); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); @@ -212,8 +214,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; - PrimExpr svalue = - cast(stype, Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); + PrimExpr svalue = cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); conds.push_back(expect_stride == svalue); expect_stride = expect_stride * buffer->shape[k]; } @@ -235,10 +236,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; std::ostringstream field_name; - field_name << v_strides->name_hint << '[' << k << ']'; + field_name << buf_strides->name << '[' << k << ']'; PrimExpr value = - cast(buffer->shape[k].dtype(), - Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); + cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); value = tvm::if_then_else(v_strides_is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); @@ -249,19 +249,17 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, for (int k = buffer->strides.size() - 1; k >= 0; k--) { std::ostringstream field_name; - field_name << v_strides->name_hint << '[' << k << ']'; + field_name << buf_strides->name << '[' << k << ']'; PrimExpr explicit_stride = - cast(buffer->shape[k].dtype(), - Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); + cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); Bind_(buffer->strides[k], tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride), field_name.str(), true); stride_from_shape *= - cast(buffer->shape[k].dtype(), - Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))); + cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})); } } // Byte_offset field. diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d7e1beff03d3..8d8020f4d06c 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -61,34 +61,63 @@ class ReturnRewriter : public StmtMutator { if (call->op.same_as(builtin::ret())) { ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope."; ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument."; - ret = WriteToOut(call->args[0], ret_var_, ret_tcode_); + ret = WriteToOut(call->args[0]); } } return ret; } private: - std::pair ConvertForFFI(PrimExpr val) { + struct ConvertedInfo { + int tcode{-1}; + PrimExpr expr; + Buffer dummy_val_buffer; + Buffer dummy_tcode_buffer; + }; + + ConvertedInfo ConvertForFFI(PrimExpr val) { + ConvertedInfo info; + // convert val's data type to FFI data type, return type code DataType dtype = val.dtype(); if (dtype.is_int() || dtype.is_uint()) { - return {kTVMArgInt, Cast(DataType::Int(64), val)}; + info.tcode = kTVMArgInt; + info.expr = Cast(DataType::Int(64), val); } else if (dtype.is_float()) { - return {kTVMArgFloat, Cast(DataType::Float(64), val)}; + info.tcode = kTVMArgFloat; + info.expr = Cast(DataType::Float(64), val); } else if (dtype.is_void()) { - return {kTVMNullptr, val}; + info.tcode = kTVMNullptr; + info.expr = val; } else { LOG(FATAL) << "data type " << dtype << " not supported yet"; } - return {kTVMNullptr, val}; + + // If multiple return locations have the same data type, use the + // same dummy buffer declaration. + auto it = dummy_val_buffer_map_.find(info.tcode); + if (it != dummy_val_buffer_map_.end()) { + info.dummy_val_buffer = it->second; + } else { + info.dummy_val_buffer = + Buffer(ret_var_, dtype, {1}, {1}, ConstInt32(0), ret_var_->name_hint, 0, 0, kDefault); + dummy_val_buffer_map_[info.tcode] = info.dummy_val_buffer; + } + + // The tcode is always a 32-bit int, so we don't need to have a separate map. + if (!dummy_tcode_buffer_.defined()) { + dummy_tcode_buffer_ = Buffer(ret_tcode_, DataType::Int(32), {1}, {1}, ConstInt32(0), + ret_tcode_->name_hint, 0, 0, kDefault); + } + info.dummy_tcode_buffer = dummy_tcode_buffer_; + + return info; } - Stmt WriteToOut(PrimExpr val, Var ret_var, Var ret_tcode) { - auto p = ConvertForFFI(val); - int tcode = p.first; - val = p.second; - Stmt store_val = Store(ret_var_, val, 0, const_true()); - Stmt store_tcode = Store(ret_tcode_, tcode, 0, const_true()); + Stmt WriteToOut(PrimExpr val) { + auto info = ConvertForFFI(val); + Stmt store_val = BufferStore(info.dummy_val_buffer, val, {0}); + Stmt store_tcode = BufferStore(info.dummy_tcode_buffer, info.tcode, {0}); Stmt ret_zero = Evaluate(tvm::ret(0)); return SeqStmt({store_val, store_tcode, ret_zero}); } @@ -96,6 +125,9 @@ class ReturnRewriter : public StmtMutator { Var ret_var_; Var ret_tcode_; int in_parallel_{0}; + + std::unordered_map dummy_val_buffer_map_; + Buffer dummy_tcode_buffer_; }; Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) { @@ -131,7 +163,8 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { // Data field definitions // The packed fields Var v_packed_args("args", DataType::Handle()); - Var v_packed_arg_type_ids("arg_type_ids", DataType::Handle()); + Buffer buf_packed_arg_type_ids = decl_buffer({IntImm(DataType::Int(32), func_ptr->params.size())}, + DataType::Int(32), "arg_type_ids"); Var v_num_packed_args("num_args", DataType::Int(32)); Var v_out_ret_value("out_ret_value", DataType::Handle()); Var v_out_ret_tcode("out_ret_tcode", DataType::Handle()); @@ -166,7 +199,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { // add signature for packed arguments. if (pack_args) { args.push_back(v_packed_args); - args.push_back(v_packed_arg_type_ids); + args.push_back(buf_packed_arg_type_ids->data); args.push_back(v_num_packed_args); } @@ -196,10 +229,8 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop)); // type code checks Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); - seq_init.emplace_back(LetStmt(tcode, - Load(DataType::Int(32), v_packed_arg_type_ids, - IntImm(DataType::Int(32), i), const_true(1)), - nop)); + seq_init.emplace_back( + LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop)); DataType t = v_arg.dtype(); if (t.is_handle()) { std::ostringstream msg; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 7f2ecf54dfcb..4274733f095e 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -107,7 +107,12 @@ class VarUseDefAnalysis : public StmtExprMutator { } Stmt VisitStmt_(const StoreNode* op) final { - this->HandleUse(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; + return Stmt(); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + this->HandleUse(op->buffer->data); return StmtExprMutator::VisitStmt_(op); } @@ -155,10 +160,33 @@ class VarUseDefAnalysis : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { - this->HandleUse(op->buffer_var); + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + this->HandleUse(op->buffer->data); return StmtExprMutator::VisitExpr_(op); } + void VisitBuffer(Buffer buffer) { + this->HandleUse(buffer->data); + auto visit_arr = [&](Array arr) { + for (const auto& element : arr) { + this->VisitExpr(element); + } + }; + + visit_arr(buffer->shape); + visit_arr(buffer->strides); + if (buffer->pre_flattened_shape) { + visit_arr(buffer->pre_flattened_shape.value()); + } + if (buffer->pre_flattened_strides) { + visit_arr(buffer->pre_flattened_strides.value()); + } + } + void HandleDef(const VarNode* v) { ICHECK(!def_count_.count(v)) << "variable " << v->name_hint << " has already been defined, the Stmt is not SSA"; From 7399589b45d47fdd9f2d1f4d62f3e3c525d47b17 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 12 Oct 2021 14:22:12 -0500 Subject: [PATCH 010/177] [UnitTest] Add unit tests to test physical layout remapping. --- .../python/unittest/test_transform_layout.py | 224 ++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100755 tests/python/unittest/test_transform_layout.py diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py new file mode 100755 index 000000000000..5b343cd4c8dd --- /dev/null +++ b/tests/python/unittest/test_transform_layout.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import functools +import sys +import pytest + +import numpy as np + +import tvm +import tvm.testing +from tvm import te +from tvm.tir.stmt_functor import post_order_visit +from tvm.driver.build_module import schedule_to_module + +dtype = tvm.testing.parameter("int32") + + +def flatten_all_indices(preflatten_shape): + def mapping(*indices): + output = 0 + for index, size in zip(indices, preflatten_shape): + output = output * size + index + return [output] + + return mapping + + +def unpack_flattened_indices(preflatten_shape): + def mapping(i): + output = [] + for dim in reversed(preflatten_shape): + output.append(i % dim) + i //= dim + return output[::-1] + + return mapping + + +def traverse(s, op, callback): + visited = set() + + def _traverse(op): + if op in visited: + return + visited.add(op) + for tensor in op.input_tensors: + _traverse(tensor.op) + callback(op) + + _traverse(op) + + +class TestCompareAgainstExplicitReshape: + A_definition_style = tvm.testing.parameter( + "explicit_reshape", + "transform_layout", + ) + B_definition_style = tvm.testing.parameter( + "explicit_reshape", + "transform_layout", + ) + + reordered_shape = tvm.testing.parameter((2, 3, 4)) + + @tvm.testing.fixture + def n_items(self, reordered_shape): + return functools.reduce(lambda x, y: x * y, reordered_shape, 1) + + @tvm.testing.fixture + def fphysical_layout(self, reordered_shape): + return unpack_flattened_indices(reordered_shape) + + @tvm.testing.fixture + def fcompute(self, A_definition_style, B_definition_style, reordered_shape, n_items, dtype): + assert A_definition_style in ["explicit_reshape", "transform_layout"] + assert B_definition_style in ["explicit_reshape", "transform_layout"] + + def func(): + if A_definition_style == "explicit_reshape": + A_input = te.placeholder(shape=reordered_shape, name="A_input", dtype=dtype) + A = te.compute( + shape=(n_items,), + fcompute=lambda i: A_input[ + i // (reordered_shape[1] * reordered_shape[2]), + (i // reordered_shape[2]) % reordered_shape[1], + i % reordered_shape[2], + ], + name="A", + ) + + elif A_definition_style == "transform_layout": + A = te.placeholder(shape=(n_items,), name="A", dtype=dtype) + A_input = A + + B = te.compute(shape=A.shape, fcompute=lambda i: A[i], name="B") + + if B_definition_style == "explicit_reshape": + B_output = te.compute( + shape=reordered_shape, + fcompute=lambda i, j, k: B[ + i * reordered_shape[1] * reordered_shape[2] + j * reordered_shape[2] + k + ], + name="B_output", + ) + elif B_definition_style == "transform_layout": + B_output = B + + return A_input, B_output + + return func + + @tvm.testing.fixture + def fschedule(self, A_definition_style, B_definition_style, fphysical_layout): + def func(outs): + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def callback(op): + if (op.name == "A" and A_definition_style == "transform_layout") or ( + op.name == "B" and B_definition_style == "transform_layout" + ): + s[op].transform_layout(fphysical_layout) + + traverse(s, outs[0].op, callback) + return s + + return func + + @tvm.testing.parametrize_targets("llvm") + def test_external_reshape( + self, target, dev, fcompute, fschedule, n_items, reordered_shape, dtype + ): + A, B = fcompute() + s = fschedule(B) + + func = tvm.build(s, [A, B], target=target, name="copy_reshape") + + a_np = np.arange(n_items).reshape(reordered_shape).astype(dtype) + b_np = np.arange(n_items).reshape(reordered_shape).astype(dtype) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.empty(b_np.shape, dtype=dtype, device=dev) + + func(a, b) + + tvm.testing.assert_allclose(b.numpy(), b_np) + + @tvm.testing.parametrize_targets("llvm") + def test_internal_reshape(self, target, dev, n_items, reordered_shape, dtype, fphysical_layout): + # The reshaping of the buffer gets flattened away in + # StorageFlatten. Therefore, testing the behavior by running only + # ApplyLayoutTransforms. + logical_shape = (n_items,) + A = te.placeholder(logical_shape, name="A", dtype=dtype) + B = te.compute(shape=logical_shape, fcompute=lambda i: A[i], name="B") + C = te.compute(shape=logical_shape, fcompute=lambda i: B[i], name="C") + + s = te.create_schedule(C.op) + s[B].transform_layout(fphysical_layout) + + mod = schedule_to_module(s, [A, C]) + body = mod["main"].body + + def walk_buffer_interactions(stmt, callback): + buffer_classes = [ + tvm.tir.BufferLoad, + tvm.tir.BufferStore, + tvm.tir.BufferRealize, + ] + + def inner(node): + if (type(node) in buffer_classes) and node.buffer.name == "B": + callback(node) + + post_order_visit(stmt, inner) + + # All references to the buffer are the same object + def check_references(): + buffer_object = None + + def inner(node): + nonlocal buffer_object + if buffer_object is None: + buffer_object = node.buffer + else: + assert node.buffer.same_as(buffer_object) + + return inner + + # The buffer has the expected shape. + def check_shape(expected_shape): + def inner(node): + assert tuple(node.buffer.shape) == expected_shape + + return inner + + # Before the transform, the buffer should be in the logical shape. + walk_buffer_interactions(body, check_references()) + walk_buffer_interactions(body, check_shape(logical_shape)) + + mod = tvm.tir.transform.ApplyLayoutTransforms()(mod) + body = mod["main"].body + + # After the transform, the buffer should be in the physical shape. + walk_buffer_interactions(body, check_references()) + walk_buffer_interactions(body, check_shape(reordered_shape)) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) From 24259f01ff28c9a6823787c344e81d7128e965a1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 10 Nov 2021 15:43:12 -0600 Subject: [PATCH 011/177] Updated tvm::address_of() to hold BufferLoad instead of Load. --- include/tvm/tir/builtin.h | 6 +++--- src/target/llvm/codegen_llvm.cc | 16 ++++++++-------- src/target/source/codegen_c.cc | 15 ++++++++------- src/target/source/codegen_opencl.cc | 9 +++++---- src/target/stackvm/codegen_stackvm.cc | 12 +++++++----- src/tir/transforms/ir_utils.h | 16 +++++++++++----- .../merge_dynamic_shared_memory_allocations.cc | 6 ++++-- src/tir/transforms/rewrite_unsafe_select.cc | 9 +++++++-- src/tir/transforms/storage_access.cc | 4 ++-- src/tir/transforms/storage_rewrite.cc | 6 ++++-- 10 files changed, 59 insertions(+), 40 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index e6665f965b5b..7bffa41f67d6 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -105,10 +105,10 @@ TVM_DLL const Op& large_uint_imm(); TVM_DLL const Op& q_multiply_shift(); /*! - * \brief See pesudo code + * \brief See pseudo code * - * Handle address_of(Load *op) { - * return &op->buffer_var[index]; + * Handle address_of(BufferLoad *op) { + * return &op->buffer_var[op->indices[0], op->indices[1], ..., op->indices[N-1]]; * } */ TVM_DLL const Op& address_of(); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 4cf89df2494e..44cf72153dc0 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -976,15 +976,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_storage_sync())) { return CreateStorageSync(op); } else if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - ICHECK(op->args.size() == 1 && l); - TypedPointer buffer_ptr; - if (const RampNode* r = l->index.as()) { - PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); - buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index)); - } else { - buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); + const BufferLoadNode* load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 1) << "LLVM only supports flat memory allocations."; + PrimExpr index = load->indices[0]; + if (const RampNode* r = index.as()) { + index = r->base / make_const(DataType::Int(32), r->lanes); } + TypedPointer buffer_ptr = + CreateBufferPtr(load->dtype, MakeValue(load->buffer->data), MakeValue(index)); unsigned addrspace = llvm::dyn_cast(buffer_ptr.addr->getType())->getAddressSpace(); return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace)); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 214009368cda..9ba321808882 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -585,15 +585,16 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->args[2], os); os << ")"; } else if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - ICHECK(op->args.size() == 1 && l); + const BufferLoadNode* load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 0) << "CodeGenC only supports flat memory allocations."; os << "(("; - this->PrintType(l->dtype.element_of(), os); - os << " *)" << this->GetVarID(l->buffer_var.get()) << " + " + this->PrintType(load->dtype.element_of(), os); + os << " *)" << this->GetVarID(load->buffer->data.get()) << " + " << "("; - this->PrintExpr(l->index, os); - if (l->dtype.bits() == 4 || (l->dtype.bits() == 1 && l->dtype.is_int())) { - os << " / " << (32 / l->dtype.bits()); + this->PrintExpr(load->indices[0], os); + if (load->dtype.bits() == 4 || (load->dtype.bits() == 1 && load->dtype.is_int())) { + os << " / " << (32 / load->dtype.bits()); } os << "))"; } else if (op->op.same_as(builtin::tvm_struct_get())) { diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 6e39306be11e..28277077179f 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -376,16 +376,17 @@ void CodeGenOpenCL::VisitStmt_(const AllocateNode* op) { void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { if (op->op.same_as(builtin::address_of())) { // Overload tvm_address_of to add storage scope (e.g. __global). - const LoadNode* load = op->args[0].as(); + const BufferLoadNode* load = op->args[0].as(); ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 0) << "CodeGenOpenCL only supports flat memory allocations."; os << "(("; - auto it = alloc_storage_scope_.find(load->buffer_var.get()); + auto it = alloc_storage_scope_.find(load->buffer->data.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, os); } this->PrintType(load->dtype.element_of(), os); - os << " *)" << this->GetVarID(load->buffer_var.get()) << " + "; - this->PrintExpr(load->index, os); + os << " *)" << this->GetVarID(load->buffer->data.get()) << " + "; + this->PrintExpr(load->indices[0], os); os << ')'; } else if (op->op.same_as(builtin::texture2d_store())) { auto* ptr_type = op->args[0].as()->type_annotation.as(); diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index e93b01becabe..e70405445349 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -193,11 +193,13 @@ void CodeGenStackVM::VisitStmt_(const AllocateNode* op) { void CodeGenStackVM::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - ICHECK(op->args.size() == 1 && l); - this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get())); - this->Push(l->index); - this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes()); + const BufferLoadNode* load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 0) << "CodeGenStackVM only supports flat memory allocations."; + + this->PushOp(StackVM::LOAD_HEAP, GetVarID(load->buffer->data.get())); + this->Push(load->indices[0]); + this->PushOp(StackVM::PUSH_I64, load->dtype.element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); } else if (op->op.same_as(builtin::reinterpret())) { diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index da52a82a2f08..55c90d25bdd6 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -103,9 +103,11 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, * \param offset the offset index. */ inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { - return Call(DataType::Handle(), builtin::address_of(), - {Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), - const_true(dtype.lanes()))}); + PrimExpr offset_expr = make_const(DataType::Int(32), offset * dtype.lanes()); + Buffer dummy_buf(handle, dtype, {offset_expr + 1}, {}, 0, handle->name_hint, 0, 0, kDefault); + BufferLoad buf_load(dummy_buf, {offset_expr}); + + return Call(DataType::Handle(), builtin::address_of(), {buf_load}); } /*! @@ -119,8 +121,12 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { offset = offset * make_const(offset.dtype(), dtype.lanes()); offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); } - return Call(DataType::Handle(), builtin::address_of(), - {Load(dtype, handle, offset, const_true(dtype.lanes()))}); + + Buffer dummy_buf(handle, dtype.element_of(), {offset + 1}, {}, 0, handle->name_hint, 0, 0, + kDefault); + BufferLoad buf_load(dummy_buf, {offset}); + + return Call(DataType::Handle(), builtin::address_of(), {buf_load}); } /*! diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 2917ef218f0b..6259aefb3bad 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -159,8 +159,10 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - this->VisitExpr(l->index); + const BufferLoadNode* load = op->args[0].as(); + for (const auto& index : load->indices) { + this->VisitExpr(index); + } } else { StmtExprVisitor::VisitExpr_(op); } diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 1ce54846aaea..8a37f9958073 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -42,8 +42,13 @@ class UnsafeExprDetector : public ExprFunctor { if (op->op.same_as(builtin::if_then_else())) { return VisitExpr(op->args[0]); } else if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - return this->VisitExpr(l->index); + const BufferLoadNode* load = op->args[0].as(); + for (const auto& index : load->indices) { + if (VisitExpr(index)) { + return true; + } + } + return false; } else if (auto* ptr_op = op->op.as()) { auto effect_kind = op_call_effect_[GetRef(ptr_op)]; if (effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation) { diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 025233df56a7..4f19f708880c 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -213,8 +213,8 @@ void StorageAccessVisitor::VisitStmt_(const WhileNode* op) { void StorageAccessVisitor::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - StmtExprVisitor::VisitExpr_(l); + const BufferLoadNode* load = op->args[0].as(); + StmtExprVisitor::VisitExpr_(load); } else if (op->op.same_as(builtin::tvm_access_ptr())) { ICHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index b1cb35341840..a07dda7ae1ad 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -142,8 +142,10 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - this->VisitExpr(l->index); + const BufferLoadNode* load = op->args[0].as(); + for (const auto& index : load->indices) { + this->VisitExpr(index); + } } else { StmtExprVisitor::VisitExpr_(op); } From 8e9ec52ad59d4036aa1a1e1d47e1b520510f3700 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 8 Oct 2021 08:52:45 -0500 Subject: [PATCH 012/177] [TIR] Added IndexMap class. Holds a set of variables representing the input indices and expressions in terms of those input indices. TODO: - Add validation, the index mapping should be invertible. - Add helper function, apply mapping to a set of indices. - Add helper function, apply mapping to bounds of input indices. --- include/tvm/tir/index_map.h | 140 ++++++++++++++++++++++++++++++++ src/tir/ir/index_map.cc | 154 ++++++++++++++++++++++++++++++++++++ 2 files changed, 294 insertions(+) create mode 100644 include/tvm/tir/index_map.h create mode 100644 src/tir/ir/index_map.cc diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h new file mode 100644 index 000000000000..237111306c2a --- /dev/null +++ b/include/tvm/tir/index_map.h @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/tir/index_map.h + * \brief Defines a remapping of buffer indices + * + * For use with tvm::tir::Buffer. + */ +#ifndef TVM_TIR_INDEX_MAP_H_ +#define TVM_TIR_INDEX_MAP_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Defines a mapping between two representations of indices + * into a buffer. + * + * This is primarily used for layout transformations of Buffer + * objects. + */ +class IndexMapNode : public Object { + public: + /*! \brief Variables representing the indices prior to remapping. + * + * If initial_indices is empty, then final_indices should also be + * empty, and no mapping is applied. + */ + Array initial_indices; + + /*! + * \brief Expressions defining the indices after remapping. + * + * These expressions should only be in terms of the initial_indices, + * and must be expressible as an IterSumExpr. The mapping from + * initial_indices to final_indices must be injective. + * + * If final_indices is empty, then initial_indices should also be + * empty, and the map is an identity function. + */ + Array final_indices; + + /*! + * \brief Default constructor + * + * Defines the mapping as an identity function, with initial_indices + * equal to the final indices. + */ + IndexMapNode() {} + + /*! + * \brief Map indices to the output space + * + * \param indices The indices in the input space. Should contain + * one value for each variable in `initial_indices`. + * + * \returns The indices in the output space. Contains one value for + * each expression in `final_indices`. + */ + Array MapIndices(const Array& indices) const; + + /*! \brief Map a memory range to the output space + * + * If contiguous memory locations in the input space are not + * necessarily contiguous in the output space (e.g. `lambda i: + * [8*(i%8) + (i//8)]`), then this will return the smallest range + * such that all valid indices are contained within the given range. + * + * \param ranges The ranges in the input space. Should contain one + * value for each variable in `initial_indices`. + * + * \returns The ranges in the output space. Contains one value for + * each expression in `final_indices`. + */ + Array MapRanges(const Array& ranges) const; + + /*! \brief Map a buffer shape to the output space + * + * \param shape The buffer shape in the input space. Should contain + * one value for each variable in `initial_indices`. + * + * \returns The buffer shape in the output space. Contains one + * value for each expression in `final_indices`. + */ + Array MapShape(const Array& shape) const; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("initial_indices", &initial_indices); + v->Visit("final_indices", &final_indices); + } + + TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object); +}; + +class IndexMap : public ObjectRef { + public: + IndexMap(Array initial_indices, Array final_indices); + + /*! \brief Generate the inverse mapping. + * + * The range of the input indices is required in order to ensure + * that the transformation is bijective over the input domain. + * + * TODO(Lunderberg): Look into allowing non-bijective + * transformations. If injective, the inverse mapping could still + * be generated with some predicate. If non-injective, could + * simplify the implementation of other optimizations (e.g. double + * buffering as a map `lambda *indices: [buffer_loop%2, *indices]`). + */ + IndexMap Inverse(Array initial_ranges) const; + + TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_INDEX_MAP_H_ diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc new file mode 100644 index 000000000000..ba0998e84ffc --- /dev/null +++ b/src/tir/ir/index_map.cc @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file index_map.cc + */ + +#include "tvm/tir/index_map.h" + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace tir { + +IndexMap::IndexMap(Array initial_indices, Array final_indices) { + auto n = make_object(); + n->initial_indices = std::move(initial_indices); + n->final_indices = std::move(final_indices); + data_ = std::move(n); +} + +IndexMap IndexMap::Inverse(Array initial_ranges) const { + // Dummy variables to represent the inverse's inputs. + Array output_vars; + for (size_t i = 0; i < (*this)->final_indices.size(); i++) { + PrimExpr index = (*this)->final_indices[i]; + // TODO(Lunderberg): Better names for these variables. A variable + // that is passed through unmodified (`index` is an element of + // `initial_indices`) should use that input index's name. A pair + // of output indices variables split from a single input index + // should be named (X.outer,X.inner). + std::stringstream ss; + ss << "axis" << i; + Var var_index(ss.str(), index.dtype()); + output_vars.push_back(var_index); + } + + // Dummy ranges for the extent of each input. + Map input_iters; + ICHECK_EQ((*this)->initial_indices.size(), initial_ranges.size()); + for (size_t i = 0; i < initial_ranges.size(); i++) { + input_iters.Set((*this)->initial_indices[i], initial_ranges[i]); + } + + // Unpack the output indices into linear combinations of the initial + // indices. + arith::Analyzer analyzer; + auto diagnostics = DiagnosticContext::Default(IRModule()); + auto iter_map = + DetectIterMap((*this)->final_indices, input_iters, 1, true, &analyzer, diagnostics); + CHECK(iter_map.size()) << "Index transformation was not bijective."; + + // Determine expressions for the input variables, in terms of the + // output variables. + Map inverse_exprs_map = + InverseAffineIterMap(iter_map, Array(output_vars.begin(), output_vars.end())); + + // Unpack the map to an array, maintaining the same parameter order. + Array inverse_exprs; + for (const auto& index : (*this)->initial_indices) { + inverse_exprs.push_back(inverse_exprs_map.at(index)); + } + + return IndexMap(output_vars, inverse_exprs); +} + +Array IndexMapNode::MapIndices(const Array& indices) const { + ICHECK_EQ(indices.size(), initial_indices.size()); + + arith::Analyzer analyzer; + + for (size_t i = 0; i < initial_indices.size(); i++) { + analyzer.Bind(initial_indices[i], indices[i]); + } + + Array output; + for (const auto& output_dim : final_indices) { + output.push_back(analyzer.Simplify(output_dim)); + } + + return output; +} + +Array IndexMapNode::MapRanges(const Array& ranges) const { + ICHECK_EQ(ranges.size(), initial_indices.size()); + + Map input_iters; + for (size_t i = 0; i < initial_indices.size(); i++) { + input_iters.Set(initial_indices[i], ranges[i]); + } + + std::unordered_map dom_map; + for (size_t i = 0; i < initial_indices.size(); i++) { + dom_map[initial_indices[i].get()] = arith::IntSet::FromRange(ranges[i]); + } + + Array output; + for (const auto& final_index : final_indices) { + auto int_set = arith::EvalSet(final_index, dom_map); + output.push_back(Range::FromMinExtent(int_set.min(), int_set.max() - int_set.min() + 1)); + } + + return output; +} + +Array IndexMapNode::MapShape(const Array& shape) const { + ICHECK_EQ(shape.size(), initial_indices.size()); + + Array ranges; + for (auto& dim : shape) { + ranges.push_back(Range(0, dim)); + } + Array mapped = MapRanges(std::move(ranges)); + + Array output; + for (auto& range : mapped) { + ICHECK(is_zero(range->min)); + output.push_back(range->extent); + } + + return output; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "index_map(" << op->initial_indices << ", " << op->final_indices << ")"; + }); + +TVM_REGISTER_NODE_TYPE(IndexMapNode); + +} // namespace tir +} // namespace tvm From 0a26801ddbb371408b3edf304fd5a9da385b84cb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 12 Nov 2021 12:37:38 -0600 Subject: [PATCH 013/177] Updated Buffer::vstore/vload to return BufferLoad/BufferStore objects. StorageFlatten/FlattenBuffer passes updated to modify the buffer/indices directly, rather than using vload/vstore. - Primary purpose of vstore/vload is to allow IR written in python to define vectorized load/store. This usage is maintained by returning a BufferLoad/BufferStore node whose index is a Ramp. - Previously, vstore/vload was also used to compute the 1-d physical index of a location within a N-d tensor. This usage will no longer be allowed, as it would not allow layout transformations to be performed after a schedule definition, but any uses of the buffer are flattened. --- include/tvm/tir/buffer.h | 7 +- src/tir/ir/buffer.cc | 144 +++++++++++++++-------- src/tir/transforms/flatten_buffer.cc | 97 ++++++++++++--- src/tir/transforms/lower_match_buffer.cc | 17 ++- src/tir/transforms/storage_flatten.cc | 92 ++++++++++----- src/tir/transforms/storage_rewrite.cc | 12 +- 6 files changed, 260 insertions(+), 109 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 69453e23ac1a..abdaf6816a14 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -127,7 +127,7 @@ class BufferNode : public Object { * without adjusting for number of lanes. (e.g. The number of * float16x4 elements in a buffer of type float16x4.) */ - PrimExpr ElemOffset(Array index) const; + Array ElemOffset(Array index) const; static constexpr const char* _type_key = "tir.Buffer"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -186,6 +186,11 @@ class Buffer : public ObjectRef { */ TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; + /*! + * \brief Get a flattened version of the buffer + */ + Buffer GetFlattenedBuffer() const; + /*! * \brief Return the storage scope associated with this buffer. */ diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 24aacc3c04f7..a30615598658 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -246,79 +246,120 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { // The buffer offset in convention of number of elements of // original data ignoring number of lanes. // We also perform optimization to simplify the indexing expression. -PrimExpr BufferNode::ElemOffset(Array index) const { - PrimExpr base = this->elem_offset; +Array BufferNode::ElemOffset(Array input_indices) const { + ICHECK_EQ(shape.size(), input_indices.size()) + << "Dimensionality of buffer must match dimensionality of index used to access it"; + + if (strides.size()) { + ICHECK_EQ(this->strides.size(), input_indices.size()) + << "If strides are defined, " + << "the index's dimensionality must match the dimensionality of the index given."; + } + + PrimExpr output_index = 0; + arith::Analyzer ana; - if (this->strides.size() == 0) { - // Scalar case - if (this->shape.size() == 0 && index.size() == 1) { - auto is_int = index[0].as(); - ICHECK(is_int && is_int->value == 0); - base = base + index[0]; - } else { - ICHECK_EQ(this->shape.size(), index.size()); - if (index.size() > 0) { - PrimExpr offset = index[0]; - for (size_t i = 1; i < index.size(); ++i) { - offset = MergeMulMod(&ana, offset * this->shape[i] + index[i]); - } - base = base + offset; - } - } - } else { - ICHECK_EQ(this->strides.size(), index.size()); - if (is_zero(base)) { - base = MergeMulMod(&ana, index[0] * this->strides[0]); + + for (size_t i = 0; i < input_indices.size(); i++) { + if (strides.size()) { + output_index = output_index + input_indices[i] * strides[i]; } else { - base = MergeMulMod(&ana, base + index[0] * this->strides[0]); + output_index = output_index * this->shape[i] + input_indices[i]; } - for (size_t i = 1; i < index.size(); ++i) { - base = MergeMulMod(&ana, base + index[i] * this->strides[i]); + + if (i > 0) { + output_index = MergeMulMod(&ana, output_index); } } - return base; + + if (elem_offset.defined() && !is_zero(elem_offset)) { + output_index = output_index + elem_offset; + } + + return {output_index}; } -inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataType dtype) { - PrimExpr offset = n->ElemOffset(index); +inline Array BufferOffset(const BufferNode* n, Array index, DataType dtype) { + Array offsets = n->ElemOffset(index); + // If the Buffer has element type with more than one lane, scale to + // get the offset in number of scalars. if (n->dtype.lanes() != 1) { - offset = offset * make_const(offset.dtype(), dtype.lanes()); + PrimExpr last_offset = offsets[offsets.size() - 1]; + offsets.Set(offsets.size() - 1, last_offset * make_const(last_offset.dtype(), dtype.lanes())); } + + // If the requested type has more than one lane, make a RampNode at + // that offset. if (dtype.lanes() != 1) { - return tir::Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); + PrimExpr last_offset = offsets[offsets.size() - 1]; + PrimExpr stride = make_const(last_offset.dtype(), 1); + offsets.Set(offsets.size() - 1, tir::Ramp(last_offset, stride, dtype.lanes())); + } + + return offsets; +} + +Buffer Buffer::GetFlattenedBuffer() const { + auto self = operator->(); + + PrimExpr output_size; + if (self->strides.size()) { + // If strides are defined, then the extent of each flattened + // buffer is the stride*size for the first input axis used for + // each output axis. + ICHECK_EQ(self->shape.size(), self->strides.size()); + output_size = self->strides[0] * self->shape[0]; + } else { - return offset; + // Otherwise, the extent of each flattened buffer is the product + // of the extents of each input axis used to generate that output + // axis. This also "flattens" rank-0 tensors to a rank-1 buffer + // of shape [1]. + + output_size = 1; + for (size_t i = 0; i < self->shape.size(); i++) { + output_size = output_size * self->shape[i]; + } } + + Buffer output = *this; + auto writer = output.CopyOnWrite(); + writer->shape = {output_size}; + + return output; } -PrimExpr Buffer::vload(Array begin, DataType dtype) const { +PrimExpr Buffer::vload(Array begin, DataType value_dtype) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); - ICHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) - << "Cannot load " << dtype << " from buffer of " << n->dtype; - if (dtype == DataType::Bool()) { - return tir::Cast(DataType::Bool(), - tir::Load(DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), - const_true())); - } else { - return tir::Load(dtype, n->data, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); + ICHECK(value_dtype.element_of() == n->dtype.element_of() && + value_dtype.lanes() % n->dtype.lanes() == 0) + << "Cannot load " << value_dtype << " from buffer of " << n->dtype; + + Array indices = begin; + int factor = value_dtype.lanes() / n->dtype.lanes(); + if (factor > 1) { + indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor)); } + return BufferLoad(*this, indices); } Stmt Buffer::vstore(Array begin, PrimExpr value) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); - DataType dtype = value.dtype(); - ICHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) - << "Cannot store " << dtype << " to buffer of " << n->dtype; - if (value.dtype() == DataType::Bool()) { - return tir::Store(n->data, tir::Cast(DataType::Int(8), value), - BufferOffset(n, begin, DataType::Int(8)), const_true()); - } else { - return tir::Store(n->data, value, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); + DataType value_dtype = value.dtype(); + ICHECK(value_dtype.element_of() == n->dtype.element_of() && + value_dtype.lanes() % n->dtype.lanes() == 0) + << "Cannot store " << value_dtype << " to buffer of " << n->dtype; + + Array indices = begin; + int factor = value_dtype.lanes() / n->dtype.lanes(); + if (factor > 1) { + indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor)); } + return BufferStore(*this, value, indices); } String Buffer::scope() const { @@ -353,7 +394,10 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const ICHECK(n != nullptr); arith::Analyzer ana; begins = SimplifyArray(&ana, begins); - PrimExpr elem_offset = ana.Simplify(n->ElemOffset(begins)); + Array elem_offset = n->ElemOffset(begins); + elem_offset.MutateByApply([&](const PrimExpr& expr) { return ana.Simplify(expr); }); + ICHECK_EQ(elem_offset.size(), 1) << "MakeSlice currently supports only flat 1-d memory."; + Array strides = n->strides; if (strides.size() == 0) { bool can_relax = true; @@ -372,7 +416,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const return MakeStrideView().MakeSlice(begins, extents); } } - return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", + return Buffer(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice", n->data_alignment, 0, n->buffer_type); } diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index e9d99cda7e13..96ad060e5896 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -46,13 +46,26 @@ PrimExpr BufferArea(const Buffer& buffer) { } /*! - * \brief Transform multi-dimension BufferLoad/BufferStore into one-dimension Load/Store + * \brief Transform multi-dimension BufferLoad/BufferStore into device-supported dimension */ class BufferFlattener : public StmtExprMutator { public: - static Stmt Flatten(const PrimFunc& f) { return BufferFlattener().VisitStmt(f->body); } + static PrimFunc Flatten(PrimFunc func) { + auto pass = BufferFlattener(func->buffer_map); + + auto writer = func.CopyOnWrite(); + writer->body = pass.VisitStmt(func->body); + writer->buffer_map = pass.updated_extern_buffer_map_; + return func; + } private: + explicit BufferFlattener(const Map& extern_buffer_map) { + for (const auto& kv : extern_buffer_map) { + updated_extern_buffer_map_.Set(kv.first, MakeFlattenedBuffer(kv.second)); + } + } + Stmt VisitStmt_(const BlockRealizeNode* op) final { // We have convert blocks into opaque blocks in previous passes. ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please " @@ -67,8 +80,8 @@ class BufferFlattener : public StmtExprMutator { } // Step 3. Handle allocations in reverse order for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { - const Buffer& buffer = new_block->alloc_buffers[i - 1]; - body = MakeAllocStmt(buffer, std::move(body)); + Buffer buffer = MakeFlattenedBuffer(new_block->alloc_buffers[i - 1]); + body = Allocate(buffer->data, buffer->dtype, buffer->shape, const_true(), std::move(body)); } return body; } @@ -112,11 +125,6 @@ class BufferFlattener : public StmtExprMutator { return body; } - Stmt VisitStmt_(const BufferStoreNode* op) final { - BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); - return store->buffer.vstore(store->indices, store->value); - } - PrimExpr VisitExpr_(const VarNode* op) final { Var var = GetRef(op); auto it = unit_loop_vars_.find(var); @@ -131,16 +139,65 @@ class BufferFlattener : public StmtExprMutator { } } + Buffer MakeFlattenedBuffer(Buffer buf) { + ICHECK_EQ(buffer_remap_.count(buf), 0) << "Multiple definitions of " << buf; + + auto flattened = buf.GetFlattenedBuffer(); + + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + if (flattened->dtype == DataType::Bool()) { + auto writer = flattened.CopyOnWrite(); + writer->dtype = DataType::Int(8); + } + + buffer_remap_[buf] = flattened; + return flattened; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + + // Handle casts from the value's dtype to the dtype of the + // backing array. + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + if (store->value.dtype() == DataType::Bool()) { + ICHECK_EQ(store->buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor"; + auto writer = store.CopyOnWrite(); + writer->value = tir::Cast(DataType::Int(8), store->value); + } + auto flattened_indices = store->buffer->ElemOffset(store->indices); + return VisitBufferAccess(std::move(store)); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + bool load_returns_bool = (op->dtype == DataType::Bool()); BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); - return load->buffer.vload(load->indices, load->dtype); + load = VisitBufferAccess(load); + + // Handle casts from dtype of the backing array to value's dtype. + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + if (load_returns_bool) { + ICHECK_EQ(load->buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor"; + return tir::Cast(DataType::Bool(), load); + } else { + return std::move(load); + } } - static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) { - String storage_scope = buffer.scope(); - PrimExpr area = BufferArea(buffer); - body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body)); - return body; + template + Node VisitBufferAccess(Node node) { + auto flattened_indices = node->buffer->ElemOffset(node->indices); + Buffer flattened_buffer = buffer_remap_.at(node->buffer); + + auto writer = node.CopyOnWrite(); + writer->buffer = flattened_buffer; + writer->indices = flattened_indices; + return node; } static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag, @@ -176,14 +233,18 @@ class BufferFlattener : public StmtExprMutator { /*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */ std::unordered_map unit_loop_vars_; + + /*! \brief Map of buffers being remapped. */ + std::unordered_map buffer_remap_; + + /*! \brief The updated external buffer map. */ + Map updated_extern_buffer_map_; }; PrimFunc FlattenBuffer(PrimFunc f) { // Only apply this pass to TIR that is not from TE schedules if (!IsFromLegacyTESchedule(f)) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = BufferFlattener::Flatten(f); - return f; + return BufferFlattener::Flatten(f); } else { return f; } diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 6bfbcef95fc5..f956c09f0457 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -185,11 +185,18 @@ class MatchBufferLower : public StmtExprMutator { indices.push_back(range->min); } - Load load = Downcast(source_buffer.vload(indices, source_buffer->dtype)); - Bind(buffer->elem_offset, load->index, buffer->name + ".elem_offset"); - CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) - << "The source elem_offset " << load->index << " does not satisfy the offset_factor " - << buffer->offset_factor << "."; + auto load = Downcast(source_buffer.vload(indices, source_buffer->dtype)); + if (load->indices.size() == 1) { + Bind(buffer->elem_offset, load->indices[0], buffer->name + ".elem_offset"); + CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) + << "The source elem_offset " << load->indices[0] + << " does not satisfy the offset_factor " << buffer->offset_factor << "."; + } else { + // Non-zero elem_offset is ill-defined for non-flat memory. + // If needed in the future, will require `Array + // elem_offsets`, with one offset for each flattened index. + Bind(buffer->elem_offset, 0); + } } // Step 2.3. Check and update strides diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index b4e770363aaa..141f5817dfcd 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1074,9 +1074,12 @@ class StorageFlattener : public StmtExprMutator { bound_analyzer(func->body); + auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes, + &bound_analyzer); + auto fptr = func.CopyOnWrite(); - fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, - &bound_analyzer)(std::move(fptr->body)); + fptr->body = pass(std::move(fptr->body)); + fptr->buffer_map = pass.UpdatedBufferMap(); return func; }; return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {}); @@ -1088,12 +1091,25 @@ class StorageFlattener : public StmtExprMutator { for (auto kv : extern_buffer_map) { BufferEntry e; e.buffer = kv.second; + e.flattened_buffer = e.buffer.GetFlattenedBuffer(); + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + + // Boolean tensors are backed by a Int8 array. + if (e.buffer->dtype == DataType::Bool()) { + auto writer = e.buffer.CopyOnWrite(); + writer->dtype = DataType::Int(8); + } e.external = true; buf_map_[kv.second] = e; + + updated_extern_buffer_map_.Set(kv.first, e.flattened_buffer); } cache_line_size_ = cache_line_size; } + Map UpdatedBufferMap() { return updated_extern_buffer_map_; } + Stmt VisitStmt_(const StoreNode* op) final { LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; return Stmt(); @@ -1131,7 +1147,18 @@ class StorageFlattener : public StmtExprMutator { const BufferEntry& e = GetBufferEntry(op->buffer); - Stmt body = e.buffer.vstore(op->indices, op->value); + // Handle casts from the value's dtype to the dtype of the backing + // array. + PrimExpr value = op->value; + if (value.dtype() == DataType::Bool()) { + ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor"; + value = tir::Cast(DataType::Int(8), value); + } + + auto flattened_indices = e.buffer->ElemOffset(op->indices); + + Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->span); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } @@ -1179,12 +1206,11 @@ class StorageFlattener : public StmtExprMutator { "Please run BufferShapeLegalize first."; } - Array shape = op->buffer->shape; StorageScope skey = StorageScope::Create(GetPtrStorageScope(op->buffer->data)); // use small alignment for small arrays auto dtype = op->buffer->dtype; - int32_t const_size = AllocateNode::constant_allocation_size(shape); + int32_t const_size = AllocateNode::constant_allocation_size(op->buffer->shape); int align = GetTempAllocaAlignment(dtype, const_size); if (skey.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(skey.to_string()); @@ -1194,35 +1220,27 @@ class StorageFlattener : public StmtExprMutator { << "Allocation exceed bound of memory tag " << skey.to_string(); } } - Array strides = op->buffer->strides; - e.buffer = Buffer(op->buffer->data, op->buffer->dtype, shape, strides, PrimExpr(), - op->buffer->name, align, 0, kDefault); + e.buffer = Buffer(op->buffer->data, op->buffer->dtype, op->buffer->shape, op->buffer->strides, + PrimExpr(), op->buffer->name, align, 0, kDefault); + e.flattened_buffer = e.buffer.GetFlattenedBuffer(); + + // TODO(Lunderberg): Move the handling of boolean into a + // dedicated pass. + + // Boolean tensors are backed by a Int8 array. + if (e.buffer->dtype == DataType::Bool()) { + auto writer = e.buffer.CopyOnWrite(); + writer->dtype = DataType::Int(8); + } buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); buf_map_[key].in_scope = false; - Stmt ret; - DataType storage_type = e.buffer->dtype; - // specially handle bool, lower its storage - // type to beDataType::Int(8)(byte) - if (storage_type == DataType::Bool()) { - storage_type = DataType::Int(8); - } - if (strides.size() != 0) { - int first_dim = 0; - ret = Allocate(e.buffer->data, storage_type, - {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); - } else { - shape = e.buffer->shape; - if (shape.size() == 0) { - shape.push_back(make_const(DataType::Int(32), 1)); - } - ret = Allocate(e.buffer->data, storage_type, shape, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); - } + Stmt ret = + Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, + make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), body); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, @@ -1250,7 +1268,17 @@ class StorageFlattener : public StmtExprMutator { if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } - return e.buffer.vload(op->indices, e.buffer->dtype); + + auto flattened_indices = e.buffer->ElemOffset(op->indices); + PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->span); + + if (op->dtype == DataType::Bool()) { + ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor"; + val = tir::Cast(DataType::Bool(), val); + } + + return val; } Stmt VisitStmt_(const PrefetchNode* op) final { @@ -1330,8 +1358,10 @@ class StorageFlattener : public StmtExprMutator { }; // The buffer entry in the flatten map struct BufferEntry { - // the buffer of storage + // The buffer object Buffer buffer; + // The updated buffer object, after flattening has been applied. + Buffer flattened_buffer; // Whether the buffer is external bool external{false}; // Whether the buffer is currently in scope. @@ -1386,6 +1416,8 @@ class StorageFlattener : public StmtExprMutator { std::unordered_set allocate_node_var_; // Buffer map std::unordered_map buf_map_; + // The extern buffer map, updated to include flattened buffers. + Map updated_extern_buffer_map_; // Collects shapes. std::vector>> shape_collector_; // bounds populator. We really need the analyzer from it. diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index a07dda7ae1ad..eae15519947f 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1232,11 +1232,13 @@ class VectorTypeAccessChecker : public StmtExprVisitor { // divisible by the number of number of lanes, and the predicate // does not apply any masking, then this array access could be // vectorized. - const RampNode* ramp_index = indices[indices.size() - 1].as(); - if (ramp_index && is_one(ramp_index->stride)) { - arith::ModularSet me = analyzer_.modular_set(ramp_index->base); - if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) { - lanes_used = ramp_index->lanes; + if (indices.size()) { + const RampNode* ramp_index = indices[indices.size() - 1].as(); + if (ramp_index && is_one(ramp_index->stride)) { + arith::ModularSet me = analyzer_.modular_set(ramp_index->base); + if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) { + lanes_used = ramp_index->lanes; + } } } From 78d88b5e04889fcb1e1ece8f387dea84db5391a0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 8 Oct 2021 09:00:23 -0500 Subject: [PATCH 014/177] [TE] Added Stage::transform_layout to the C++ TE implementation. Adds an `Array` in the stage to define the transformations to be applied on the tensor's layout. As of this commit, this mapping isn't propagated into the TIR graph yet. --- include/tvm/te/schedule.h | 27 +++++++++ python/tvm/te/schedule.py | 96 +++++++++++++++++++++++++++++++- src/te/schedule/schedule_lang.cc | 9 +++ 3 files changed, 131 insertions(+), 1 deletion(-) diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index 17aedbcff308..eef01ad1b0f1 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -256,6 +257,29 @@ class Stage : public ObjectRef { * \return reference to self. */ TVM_DLL Stage& rolling_buffer(); // NOLINT(*) + /*! + * \brief Defines a layout transformation to be applied to the buffer. + * + * The map from initial_index to final_index must be an + * invertible affine transformation. + * + * \param initial_indices An array of variables to represent a + * value's location in the tensor, using the pre-transformation + * layout. These variables are used as binding occurrences to + * represent the initial indices when applying the initial->final + * mapping, and should not occur elsewhere in the + * Schedule. (i.e. Pass in newly constructed variables, not the + * initial IterVar::var) + * + * \param final_indices An array of expressions, giving the + * value's location in the tensor, using the post-transformation layout. + * Expressions should be in terms of the variables given in + * initial_indices. + * + * \return reference to self + */ + TVM_DLL Stage& transform_layout(const Array& initial_indices, + const Array& final_indices); /*! * \brief whether the stage has been scheduled. * \return whether the stage has been scheduled. @@ -500,6 +524,8 @@ class StageNode : public Object { bool double_buffer{false}; /*! \brief Whether apply rolling buffer optimization to this stage */ bool rolling_buffer{false}; + /*! \brief Layout transformations to be applied onto the stage's tensors. */ + Array layout_transforms; /*! * \brief The parent group of the current stage. * The stage cannot be assigned to stages outside the group. @@ -522,6 +548,7 @@ class StageNode : public Object { v->Visit("scope", &scope); v->Visit("is_output", &is_output); v->Visit("double_buffer", &double_buffer); + v->Visit("layout_transforms", &layout_transforms); v->Visit("group", &group); v->Visit("num_child_stages", &num_child_stages); } diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index 55d07a57e3e4..5c3eda89d3d9 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -16,12 +16,16 @@ # under the License. # pylint: disable=unused-import """The computation schedule api of TVM.""" +import collections +import inspect +from typing import Callable, List + import tvm._ffi from tvm._ffi.base import string_types from tvm.runtime import Object, convert from tvm.ir import container as _container -from tvm.tir import IterVar, Buffer +from tvm.tir import IterVar, Buffer, Var from . import tensor as _tensor from . import _ffi_api @@ -519,9 +523,99 @@ def rolling_buffer(self): """ _ffi_api.StageRollingBuffer(self) + def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr]]): + """Defines the layout transformation for the current stage's tensor. + + The map from initial_indices to final_indices must be an + invertible affine transformation. + + This method may be called more than once for a given tensor, in which case each + + Parameters + ---------- + mapping_function : Callable[..., List[tvm.tir.PrimExpr]] + + A callable that accepts N arguments of type tvm.tir.Var, + and outputs a list of PrimExpr. The input arguments + represent the location of a value in the current stage's + tensor, using the pre-transformation layout. The return + value of the function gives the location of that value in + the current stage's tensor, using the post-transformation + layout. + + Examples + -------- + .. code-block:: python + + # ``A`` is a tensor whose compute definition is in NHWC + # format, and should be transformed into NCHWc format. + + s[A].transform_layout( + lambda n,h,w,c: [n, c//4, h, w, c%4] + ) + + + .. code-block:: python + + # ``A`` is a tensor whose compute definition is in format, + # and should be transformed such that the last index is + # split, with the slower-chan index of the split placed at the + # slowest changing dimension. + + s[A].transform_layout( + lambda *indices, i: [i//4, *indices, i%4] + ) + + """ + + args = [] + var_arg_name = None + kwargs = collections.OrderedDict() + default_index_dtype = 'int32' + + # Make a dummy variable for each explicitly named input index. + # We may have some keyword-only arguments, if the function has + # *args before the last argument. + params = inspect.signature(mapping_function).parameters + for name, param in params.items(): + if param.kind in [inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]: + args.append(tvm.tir.Var(name, default_index_dtype)) + + elif param.kind == inspect.Parameter.VAR_POSITIONAL: + var_arg_name = name + + elif param.kind == inspect.Parameter.KEYWORD_ONLY: + kwargs[name] = tvm.tir.Var(name, default_index_dtype) + + elif param.kind in [inspect.Parameter.VAR_KEYWORD]: + raise ValueError("transform_layout mapping may not have **kwargs") + + ndim = len(self.op.output(0).shape) + + # Now that all the named arguments have been collected, + # everything that remains should go to the *args, if + # specified. + if var_arg_name is not None: + num_var_args = ndim - len(args) - len(kwargs): + for i in range(num_var_args): + args.append(tvm.tir.Var(f'{var_arg_name}[{i}]', default_index_dtype)) + + initial_indices = args + list(kwargs.values()) + if len(initial_indices) != ndim: + raise ValueError( + f"transform_layout mapping accepts {len(params)} initial indices, " + f"but {self.op.name} is {len(self.op.shape)}-dimensional" + ) + + final_indices = mapping_function(*args, **kwargs) + + _ffi_api.StageTransformLayout(self, initial_indices, final_indices) + + @tvm._ffi.register_object class SpecializedCondition(Object): + """Specialized condition to enable op specialization.""" def __init__(self, conditions): diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index 2f74d2905454..b7a73e2e3adf 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -429,6 +429,13 @@ Stage& Stage::rolling_buffer() { self->rolling_buffer = true; return *this; } +Stage& Stage::transform_layout(const Array& initial_indices, + const Array& final_indices) { + StageNode* self = operator->(); + + self->layout_transforms.push_back(IndexMap(initial_indices, final_indices)); + return *this; +} Stage CopyStage(const Stage& s) { ObjectPtr n = make_object(*s.operator->()); @@ -895,6 +902,8 @@ TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffe TVM_REGISTER_GLOBAL("te.StageRollingBuffer").set_body_method(&Stage::rolling_buffer); +TVM_REGISTER_GLOBAL("te.StageTransformLayout").set_body_method(&Stage::transform_layout); + TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize); TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group); From 4fcb73bacd45c3d006b725c0f5d306115db11718 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 10 Dec 2021 21:19:41 -0600 Subject: [PATCH 015/177] Replace Store/Load with BufferStore/BufferLoad in ir_builder --- python/tvm/tir/ir_builder.py | 94 +++++++++++++++++------------------- src/tir/ir/buffer.cc | 2 + 2 files changed, 46 insertions(+), 50 deletions(-) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index a71476b23e44..52c54f4bf720 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -17,10 +17,11 @@ """Developer API of IR node builder make function.""" from tvm._ffi.base import string_types from tvm.runtime import ObjectGeneric, DataType, convert, const -from tvm.ir import container as _container, PointerType, PrimType +from tvm.ir import container as _container from . import stmt as _stmt from . import expr as _expr +from . import buffer as _buffer from . import op @@ -43,84 +44,77 @@ class BufferVar(ObjectGeneric): Do not create it directly, create use IRBuilder. - BufferVars support array access either via a linear index, or, if given a - shape, via a multidimensional index. + Array access through a BufferVar must use the same number of + indices as the underlying buffer was declared to have. Examples -------- In the follow example, x is BufferVar. - :code:`x[0] = ...` directly emit a store to the IRBuilder, - :code:`x[10]` translates to Load. + :code:`x[0] = ...` directly emit a BufferStore to the IRBuilder, + :code:`x[10]` translates to BufferLoad. .. code-block:: python - # The following code generate IR for x[0] = x[ + # The following code generate IR for x[0] = x[10] + 1 ib = tvm.tir.ir_builder.create() - x = ib.pointer("float32") + x = ib.allocate("float32", 20) x[0] = x[10] + 1 + # Array access using a multidimensional index y = ib.allocate("float32", (32, 32)) - # Array access using a linear index - y[(2*32) + 31] = 0. - # The same array access using a multidimensional index y[2, 31] = 0. See Also -------- IRBuilder.pointer - IRBuilder.buffer_ptr IRBuilder.allocate + """ - def __init__(self, builder, buffer_var, shape, content_type): + def __init__(self, builder, buffer, content_type): self._builder = builder - self._buffer_var = buffer_var - self._shape = shape + self._buffer = buffer self._content_type = content_type def asobject(self): - return self._buffer_var + return self._buffer @property def dtype(self): return self._content_type - def _linear_index(self, index): - if not isinstance(index, tuple) or self._shape is None: - return index - assert len(index) == len(self._shape), "Index size (%s) does not match shape size (%s)" % ( - len(index), - len(self._shape), - ) - dim_size = 1 - lidx = 0 - for dim, idx in zip(reversed(self._shape), reversed(index)): - lidx += idx * dim_size - dim_size *= dim - return lidx + def _normalize_index(self, index): + try: + index = [*index] + except TypeError: + index = [index] - def __getitem__(self, index): t = DataType(self._content_type) - index = self._linear_index(index) if t.lanes > 1: - base = index * t.lanes + base = index[-1] * t.lanes stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype) - index = _expr.Ramp(base, stride, t.lanes) - return _expr.Load(self._content_type, self._buffer_var, index) + index[-1] = _expr.Ramp(base, stride, t.lanes) + + index = [x.var if isinstance(x, _expr.IterVar) else x for x in index] + + return index + + def __getitem__(self, index): + index = self._normalize_index(index) + return _expr.BufferLoad(self._buffer, index) def __setitem__(self, index, value): + index = self._normalize_index(index) + value = convert(value) - if value.dtype != self._content_type: + value_element = value.dtype.split("x", maxsplit=1)[0] + content_element = self._content_type.split("x", maxsplit=1)[0] + if value_element != content_element: raise ValueError( "data type does not match content type %s vs %s" % (value.dtype, self._content_type) ) - index = self._linear_index(index) - t = DataType(self._content_type) - if t.lanes > 1: - base = index * t.lanes - stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype) - index = _expr.Ramp(base, stride, t.lanes) - self._builder.emit(_stmt.Store(self._buffer_var, value, index)) + + self._builder.emit(_stmt.BufferStore(self._buffer, value, index)) class IRBuilder(object): @@ -417,11 +411,14 @@ def allocate(self, dtype, shape, name="buf", scope=""): buffer : BufferVar The buffer var representing the buffer. """ - buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] + + buffer = _buffer.decl_buffer(shape, dtype, name, scope=scope) + + buffer_var = buffer.data self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) - return BufferVar(self, buffer_var, shape, dtype) + return BufferVar(self, buffer, dtype) def pointer(self, content_type, name="ptr", scope=""): """Create pointer variable with content type. @@ -442,10 +439,10 @@ def pointer(self, content_type, name="ptr", scope=""): ptr : BufferVar The buffer var representing the buffer. """ - buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope)) - return BufferVar(self, buffer_var, None, content_type) + buffer = _buffer.decl_buffer(shape=[], dtype=content_type, name=name, scope=scope) + return BufferVar(self, buffer, content_type) - def buffer_ptr(self, buf, shape=None): + def buffer_ptr(self, buf): """Create pointer variable corresponds to buffer ptr. Parameters @@ -453,15 +450,12 @@ def buffer_ptr(self, buf, shape=None): buf : Buffer The buffer to be extracted. - shape : Tuple - Optional shape of the buffer. Overrides existing buffer shape. - Returns ------- ptr : BufferVar The buffer var representing the buffer. """ - return BufferVar(self, buf.data, buf.shape if shape is None else shape, buf.dtype) + return BufferVar(self, buf, buf.dtype) def likely(self, expr): """Add likely tag for expression. diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index a30615598658..41d16652c635 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -508,6 +508,8 @@ TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); +TVM_REGISTER_GLOBAL("tir.BufferGetFlattenedBuffer").set_body_method(&Buffer::GetFlattenedBuffer); + TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); From 97d6dc90ca681f0a2b9bc141bde72affeb0085be Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 11 Oct 2021 09:38:05 -0500 Subject: [PATCH 016/177] [TE] Added Stage.transform_layout to the Python TE interface. Allows users to specify `s[A].transform_layout(mapping)`, and propagate into the TE definitions. --- include/tvm/tir/stmt.h | 7 +++++++ python/tvm/te/schedule.py | 11 +++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index edb789b0bd7f..393cefb87411 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1280,6 +1280,13 @@ constexpr const char* pragma_tensor_core = "pragma_tensor_core"; * run prefetch of Tensor on the current loop scope */ constexpr const char* prefetch_scope = "prefetch_scope"; +/*! + * \brief Marks the physical layout to be used for a tensor. + * + * Only applies to a DataProducer, as it should be made part of the + * Buffer definition in a PrimFunc. + */ +constexpr const char* physical_layout = "physical_layout"; /*! * \brief Marks production of double buffer data */ diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index 5c3eda89d3d9..b8f0ab7438fa 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -571,14 +571,17 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr args = [] var_arg_name = None kwargs = collections.OrderedDict() - default_index_dtype = 'int32' + default_index_dtype = "int32" # Make a dummy variable for each explicitly named input index. # We may have some keyword-only arguments, if the function has # *args before the last argument. params = inspect.signature(mapping_function).parameters for name, param in params.items(): - if param.kind in [inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]: + if param.kind in [ + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ]: args.append(tvm.tir.Var(name, default_index_dtype)) elif param.kind == inspect.Parameter.VAR_POSITIONAL: @@ -596,9 +599,9 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr # everything that remains should go to the *args, if # specified. if var_arg_name is not None: - num_var_args = ndim - len(args) - len(kwargs): + num_var_args = ndim - len(args) - len(kwargs) for i in range(num_var_args): - args.append(tvm.tir.Var(f'{var_arg_name}[{i}]', default_index_dtype)) + args.append(tvm.tir.Var(f"{var_arg_name}[{i}]", default_index_dtype)) initial_indices = args + list(kwargs.values()) if len(initial_indices) != ndim: From a5235744096a81f4abdf068e3ced11248319bf42 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 16 Nov 2021 08:48:58 -0600 Subject: [PATCH 017/177] Added pre_flattened_shape/pre_flattened_stride fields to Buffer. The shape and stride checks performed in ArgBinder::BindDLTensor (called from MakePackedAPI) require the tensor shape/strides prior to index flattening. Therefore, though it is no longer used by the low-level code generators, we must maintain that information for use in MakePackedAPI. --- include/tvm/tir/buffer.h | 42 +++++++++++++++- src/tir/ir/buffer.cc | 12 ++++- src/tir/transforms/arg_binder.cc | 85 +++++++++++++++++++------------- 3 files changed, 101 insertions(+), 38 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index abdaf6816a14..b06698662e26 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -55,8 +55,39 @@ class BufferNode : public Object { Var data; /*! \brief data type in the content of the tensor */ DataType dtype; - /*! \brief The shape of the buffer */ + /*! \brief The shape of the buffer + * + * This contains the shape as it is accessed by + * BufferLoad/BufferStore nodes, and used by the low-level code + * generators. + */ Array shape; + /*! \brief The shape of the buffer prior to flattening + * + * This contains the shape as it exists prior to flattening, and is + * used for validating the shape of the tensor passed into the + * packed API. + * + * TODO(Lunderberg): Should this be a reference to the entire + * pre-flattened Buffer instead of just the shape? That would also + * allow the PackedFunc to know how ArgBinder::BindDLTensor (called + * from MakePackedAPI) to know how the tensor should be flattened as + * it is being transferred from the device. + */ + Optional> pre_flattened_shape; + /*! \brief The strides of the buffer prior to flattening + * + * This contains the strides as they exists prior to flattening, and + * is used for validating an input tensor passed into the packed + * API. + * + * TODO(Lunderberg): Should this be a reference to the entire + * pre-flattened Buffer instead of just the strides? That would + * also allow the PackedFunc to know how ArgBinder::BindDLTensor + * (called from MakePackedAPI) to know how the tensor should be + * flattened as it is being transferred from the device. + */ + Optional> pre_flattened_strides; /*! * \brief The strides of each dimension * This can be an empty array, indicating array is contiguous @@ -88,6 +119,8 @@ class BufferNode : public Object { v->Visit("data", &data); v->Visit("dtype", &dtype); v->Visit("shape", &shape); + v->Visit("pre_flattened_shape", &pre_flattened_shape); + v->Visit("pre_flattened_strides", &pre_flattened_strides); v->Visit("strides", &strides); v->Visit("elem_offset", &elem_offset); v->Visit("name", &name); @@ -101,7 +134,10 @@ class BufferNode : public Object { // Use DefEqual as buffer can define variables // in its semantics, skip name as name is not important. return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && - equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && + equal.DefEqual(shape, other->shape) && + equal.DefEqual(pre_flattened_shape, other->pre_flattened_shape) && + equal.DefEqual(pre_flattened_strides, other->pre_flattened_strides) && + equal.DefEqual(strides, other->strides) && equal.DefEqual(elem_offset, other->elem_offset) && equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); } @@ -110,6 +146,8 @@ class BufferNode : public Object { hash_reduce.DefHash(data); hash_reduce(dtype); hash_reduce.DefHash(shape); + hash_reduce.DefHash(pre_flattened_shape); + hash_reduce.DefHash(pre_flattened_strides); hash_reduce.DefHash(strides); hash_reduce.DefHash(elem_offset); hash_reduce(data_alignment); diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 41d16652c635..e0f4466e9007 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -322,9 +322,19 @@ Buffer Buffer::GetFlattenedBuffer() const { } } + // If a flattening pass is called multiple times, then the + // pre-flattened shape/strides should be from before the first + // application of the pass. + auto pre_flattened_shape = (*this)->pre_flattened_shape.value_or(self->shape); + auto pre_flattened_strides = (*this)->pre_flattened_strides.value_or(self->strides); + Buffer output = *this; auto writer = output.CopyOnWrite(); - writer->shape = {output_size}; + writer->pre_flattened_shape = pre_flattened_shape; + writer->pre_flattened_strides = pre_flattened_strides; + writer->shape = output_shape; + writer->axis_separators = output_axis_separators; + writer->strides = {}; return output; } diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 013297c2550c..eac6ab804495 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -154,9 +154,30 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const Stmt nop = Evaluate(0); // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); - PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); + ICHECK(buffer->pre_flattened_shape) + << "Cannot bind tensor argument to an unflattened buffer. " + << "Please run StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules) first."; + auto pre_flattened_shape = buffer->pre_flattened_shape.value(); + + ICHECK(buffer->pre_flattened_strides) + << "Cannot bind tensor argument to an unflattened buffer. " + << "Please run StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules) first."; + auto pre_flattened_strides = buffer->pre_flattened_strides.value(); + + // Helper functions for shape/stride name formatting + auto shape_handle_name = [&]() { return arg_name + ".shape"; }; + auto stride_handle_name = [&]() { return arg_name + ".strides"; }; + auto array_element_name = [&](const std::string& arr_name, size_t k) { + std::stringstream ss; + ss << arr_name << '[' << k << ']'; + return ss.str(); + }; + auto shape_element_name = [&](size_t k) { return array_element_name(shape_handle_name(), k); }; + auto stride_element_name = [&](size_t k) { return array_element_name(stride_handle_name(), k); }; + + PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(pre_flattened_shape.size())); std::ostringstream ndim_err_msg; - ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); + ndim_err_msg << arg_name << ".ndim is expected to equal " << pre_flattened_shape.size(); auto msg = tvm::tir::StringImm(ndim_err_msg.str()); asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); // type checks @@ -184,43 +205,42 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, IntImm(DataType::Int(32), buffer->data_alignment), nop)); } - Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())}, tvm_shape_type, - arg_name + ".shape"); - Var v_shape(arg_name + ".shape", DataType::Handle()); + // shape field + Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), pre_flattened_shape.size())}, + tvm_shape_type, shape_handle_name()); + Var v_shape(shape_handle_name(), DataType::Handle()); def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); init_nest_.emplace_back( LetStmt(buf_shape->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); - for (size_t k = 0; k < buffer->shape.size(); ++k) { + for (size_t k = 0; k < pre_flattened_shape.size(); ++k) { if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { break; } - std::ostringstream field_name; - field_name << v_shape->name_hint << '[' << k << ']'; - Bind_(buffer->shape[k], - cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})), - field_name.str(), true); + Bind_( + pre_flattened_shape[k], + cast(pre_flattened_shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})), + shape_element_name(k), true); } // strides field - Buffer buf_strides = decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())}, + Buffer buf_strides = decl_buffer({IntImm(DataType::Int(32), pre_flattened_strides.size())}, tvm_shape_type, arg_name + ".strides"); def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back(LetStmt( buf_strides->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data}); - if (buffer->strides.size() == 0) { + if (pre_flattened_strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); PrimExpr expect_stride = make_const(stype, 1); Array conds; - for (size_t i = buffer->shape.size(); i != 0; --i) { + for (size_t i = pre_flattened_shape.size(); i != 0; --i) { size_t k = i - 1; PrimExpr svalue = cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); conds.push_back(expect_stride == svalue); - expect_stride = expect_stride * buffer->shape[k]; + expect_stride = expect_stride * pre_flattened_shape[k]; } std::ostringstream stride_err_msg; - stride_err_msg << arg_name << ".strides:" - << " expected to be compact array"; + stride_err_msg << stride_handle_name() << ": expected to be compact array"; if (conds.size() != 0) { auto stride_msg = tvm::tir::StringImm(stride_err_msg.str()); Stmt check = AssertStmt( @@ -233,33 +253,28 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, } else if (buffer->buffer_type == kAutoBroadcast) { DataType stype = buffer->DefaultIndexType(); PrimExpr stride = make_const(stype, 1); - for (size_t i = buffer->shape.size(); i != 0; --i) { + for (size_t i = pre_flattened_shape.size(); i != 0; --i) { size_t k = i - 1; - std::ostringstream field_name; - field_name << buf_strides->name << '[' << k << ']'; - PrimExpr value = - cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); + PrimExpr value = cast(pre_flattened_shape[k].dtype(), + BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); value = tvm::if_then_else(v_strides_is_null, stride, value); - value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); - Bind_(buffer->strides[k], value, field_name.str(), true); - stride = analyzer_.Simplify(stride * buffer->shape[k]); + value = tvm::if_then_else(pre_flattened_shape[k] == 1, 0, value); + Bind_(pre_flattened_strides[k], value, stride_element_name(k), true); + stride = analyzer_.Simplify(stride * pre_flattened_shape[k]); } } else { PrimExpr stride_from_shape = 1; - for (int k = buffer->strides.size() - 1; k >= 0; k--) { - std::ostringstream field_name; - field_name << buf_strides->name << '[' << k << ']'; - - PrimExpr explicit_stride = - cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); + for (int k = pre_flattened_strides.size() - 1; k >= 0; k--) { + PrimExpr explicit_stride = cast(pre_flattened_shape[k].dtype(), + BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); - Bind_(buffer->strides[k], + Bind_(pre_flattened_strides[k], tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride), - field_name.str(), true); + stride_element_name(k), true); - stride_from_shape *= - cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})); + stride_from_shape *= cast(pre_flattened_shape[k].dtype(), + BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})); } } // Byte_offset field. From 3211d617d094cbf5b8d8be98f34d94c0827693a8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 22 Oct 2021 12:32:43 -0500 Subject: [PATCH 018/177] [UnitTest] Test N-d indices exposed to low-level codegen When using te.AXIS_SEPARATOR in the call to .transform_layout, this should define groups of axes, each of which is flattened to a single axis, then exposed to the low-level codegen. --- .../python/unittest/test_transform_layout.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py index 5b343cd4c8dd..66d03dd67eda 100755 --- a/tests/python/unittest/test_transform_layout.py +++ b/tests/python/unittest/test_transform_layout.py @@ -220,5 +220,64 @@ def inner(node): walk_buffer_interactions(body, check_shape(reordered_shape)) +class Test2DPhysicalLayout: + transform_A = tvm.testing.parameter( + by_dict={ + "2d_A": True, + "1d_A": False, + } + ) + transform_B = tvm.testing.parameter( + by_dict={ + "2d_B": True, + "1d_B": False, + } + ) + + @staticmethod + def extract_loop_vars(stmt): + output = [] + + def callback(node): + if isinstance(node, tvm.tir.For): + output.append(node.loop_var) + + post_order_visit(stmt, callback) + return output[::-1] + + def test_2d_physical(self, dtype, transform_A, transform_B): + logical_shape = (2, 3, 4) + A = te.placeholder(shape=logical_shape, dtype=dtype, name="A") + B = te.compute(shape=A.shape, fcompute=lambda i, j, k: A[i, j, k], name="B") + + s = te.create_schedule(B.op) + + if transform_A: + s[A].transform_layout(lambda i, j, k: [i, j, te.AXIS_SEPARATOR, k]) + + if transform_B: + s[B].transform_layout(lambda i, j, k: [i, j, te.AXIS_SEPARATOR, k]) + + mod = tvm.lower(s, [A, B]) + + i, j, k = self.extract_loop_vars(mod["main"].body) + indices_1d = [i * (logical_shape[1] * logical_shape[2]) + j * logical_shape[2] + k] + indices_2d = [i * logical_shape[1] + j, k] + + def callback(node): + if type(node) in [tvm.tir.BufferLoad, tvm.tir.BufferStore]: + name = node.buffer.name + if name == "A": + expected_indices = indices_2d if transform_A else indices_1d + elif name == "B": + expected_indices = indices_2d if transform_B else indices_1d + else: + raise RuntimeError(f"Unexpected buffer: {name}") + + tvm.ir.assert_structural_equal(expected_indices, node.indices) + + post_order_visit(mod["main"].body, callback) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) From 74d75dcc70745a4afcae191b3ea9080e5acb3a73 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 11 Oct 2021 09:39:14 -0500 Subject: [PATCH 019/177] [TIR] Added PrimFunc attribute "layout_transform_map", filled from TE. Propagated the TE definition of the physical layout into the TIR graph. --- include/tvm/tir/stmt.h | 6 +- src/te/schedule/schedule_ops.cc | 25 +++- .../schedule/schedule_postproc_to_primfunc.cc | 117 ++++++++++++++++-- 3 files changed, 129 insertions(+), 19 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 393cefb87411..37d6193697f1 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1281,12 +1281,12 @@ constexpr const char* pragma_tensor_core = "pragma_tensor_core"; */ constexpr const char* prefetch_scope = "prefetch_scope"; /*! - * \brief Marks the physical layout to be used for a tensor. + * \brief Marks the layout transforms to be used for a tensor. * * Only applies to a DataProducer, as it should be made part of the - * Buffer definition in a PrimFunc. + * PrimFunc attributes for TIR. */ -constexpr const char* physical_layout = "physical_layout"; +constexpr const char* layout_transforms = "layout_transforms"; /*! * \brief Marks production of double buffer data */ diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 1568df4670af..62763d8d51d6 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -40,12 +40,26 @@ namespace te { using namespace tir; +// Annotate the statement with the physical layout of the stage. This +// annotation is removed during SchedulePostProcToPrimFunc, where it +// becomes part of the PrimFunc attrs. +Stmt WrapLayoutTransformationAttrs(const Stage& stage, Stmt body) { + if (stage->layout_transforms.size()) { + for (int i = 0; i < stage->op->num_outputs(); i++) { + body = AttrStmt(Array{stage->op.output(i), stage->layout_transforms}, + tir::attr::layout_transforms, 1, body); + } + } + return body; +} + Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_map, Stmt consumer, bool debug_keep_trivial_loop) { Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop); if (s->double_buffer) { producer = AttrStmt(s->op, tir::attr::double_buffer_scope, 1, producer); } + producer = WrapLayoutTransformationAttrs(s, producer); Stmt pipeline = producer; if (consumer.defined() && !is_no_op(consumer)) { @@ -343,12 +357,16 @@ Stmt ScheduleOps(Schedule sch, Map dom_map_, bool debug_keep_tri Stage s = sch->stages[i - 1]; ICHECK_NE(s->attach_type, kInline) << "call schedule.normalize before scheduleops"; ICHECK(s->op.defined()); - // no need to specify place holder op. - if (s->op.as()) continue; // Remove grouping sugar, get the real attach spec. Stage attach_spec = s.GetAttachSpec(); - if (scan_init.count(s->op)) { + if (s->op.as()) { + // Placeholders don't need any realize/provide statements, but + // may be annotated with set_physical_layout to indicate the + // physical layout of an input, and must still have the + // attribute given. + body = WrapLayoutTransformationAttrs(s, std::move(body)); + } else if (scan_init.count(s->op)) { ICHECK(body.defined()); InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop); body = mu(std::move(body)); @@ -375,6 +393,7 @@ Stmt ScheduleOps(Schedule sch, Map dom_map_, bool debug_keep_tri << body; } } + SchedulePostProc post_proc; post_proc.Init(sch); return post_proc(std::move(body)); diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 7e8b12b6d61e..55c91e7f52b9 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -42,6 +42,7 @@ #include #include +#include #include #include @@ -55,6 +56,7 @@ Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") { name += ".v" + std::to_string(tensor->value_index); } Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope); + return buffer; } @@ -86,6 +88,16 @@ class TensorToBufferMapper : public StmtExprMutator { Tensor tensor = Downcast(op->node); Buffer buffer = GetOrAllocBuffer(tensor); return AttrStmt(buffer, op->attr_key, op->value, op->body); + } else if (op->attr_key == tir::attr::layout_transforms) { + auto arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2); + + Stmt body = op->body; + + Tensor tensor = Downcast(arr[0]); + Buffer buffer = GetBuffer(tensor); + + return AttrStmt(Array{buffer, arr[1]}, op->attr_key, 1, body); } else { return ret; } @@ -134,46 +146,125 @@ class TensorToBufferMapper : public StmtExprMutator { return buffer; } - // maps tensor to buffer. + // Maps tensor to buffer. std::unordered_map buffer_map_; }; +/*! Collect the physical layout map of all tensors in the statement. */ +class LayoutTransformAttrUnwrapper : StmtExprMutator { + public: + static tir::PrimFunc Apply(tir::PrimFunc func) { + // Collect the physical layout annotations in the body, which may + // refer to input arguments. + auto layout_map = Collector::Collect(func->body); + + if (layout_map.size()) { + func = WithAttr(std::move(func), "layout_transform_map", layout_map); + + auto write_ptr = func.CopyOnWrite(); + write_ptr->body = LayoutTransformAttrUnwrapper()(func->body); + } + + return func; + } + + LayoutTransformAttrUnwrapper() {} + + Stmt VisitStmt_(const AttrStmtNode* op) final { + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + if (op->attr_key == tir::attr::layout_transforms) { + return op->body; + } else { + return ret; + } + } + + private: + /*! Collect the physical layout information of all tensors in the statement. + * + * Must be done before constructing the buffers, since the + * attributes could either apply to the external buffers or to + * internal allocations. + */ + class Collector : StmtExprVisitor { + public: + static Map> Collect(Stmt stmt) { + Collector collector; + collector(std::move(stmt)); + return std::move(collector.layout_map_); + } + + Collector() {} + + void VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tir::attr::layout_transforms) { + auto arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2); + + auto buffer = Downcast(arr[0]); + auto layout_transforms = Downcast>(arr[1]); + layout_map_.Set(buffer, layout_transforms); + } + StmtExprVisitor::VisitStmt_(op); + } + + private: + Map> layout_map_; + }; + + std::unordered_map buffer_remap_; + + Map> layout_map_; +}; + PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, Optional> extern_buffer_opt) { - std::unordered_map extern_buffer; + std::unordered_map extern_tensor_map; if (extern_buffer_opt.defined()) { auto v = extern_buffer_opt.value(); - extern_buffer = std::unordered_map(v.begin(), v.end()); + extern_tensor_map = std::unordered_map(v.begin(), v.end()); } Array params; Map buffer_map; - for (auto var : arg_list) { - if (auto* n = var.as()) { + for (auto arg : arg_list) { + if (auto* n = arg.as()) { + tir::Var var = GetRef(n); params.push_back(GetRef(n)); - } else if (auto* n = var.as()) { + } else if (auto* n = arg.as()) { te::Tensor tensor = GetRef(n); - ICHECK(!extern_buffer.count(tensor)); + ICHECK(!extern_tensor_map.count(tensor)); tir::Buffer buffer = CreateBufferFor(tensor); tir::Var bptr(buffer->name, PrimType(DataType::Handle())); params.push_back(bptr); buffer_map.Set(bptr, buffer); - extern_buffer[tensor] = buffer; - } else { - tir::Buffer buffer = Downcast(var); + extern_tensor_map[tensor] = buffer; + } else if (auto* n = arg.as()) { + tir::Buffer buffer = GetRef(n); tir::Var bptr(buffer->name, PrimType(DataType::Handle())); params.push_back(bptr); buffer_map.Set(bptr, buffer); + } else { + LOG(FATAL) << "Expected argument to be Var, Tensor, or Buffer, but received " + << arg->GetTypeKey(); } } - body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body)); + body = TensorToBufferMapper(std::move(extern_tensor_map))(std::move(body)); + + PrimFunc func = tir::PrimFunc(params, body, VoidType(), buffer_map); + + func = LayoutTransformAttrUnwrapper::Apply(std::move(func)); + // We mark this PrimFunc as coming from a TE schedule - return WithAttr(tir::PrimFunc(params, body, VoidType(), buffer_map), "from_legacy_te_schedule", - Bool(true)); + func = WithAttr(func, "from_legacy_te_schedule", Bool(true)); + + return func; } TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc") From ace752553eea5c990d90ab3e402f47470524d328 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 6 Jan 2022 16:16:37 -0600 Subject: [PATCH 020/177] Added pre_flattened_type. If a boolean tensor is backed by an int8 buffer, the check on the argument buffer's type should be against the boolean type. When rebasing this PR, should be placed after the addition of pre_flatten_shape/pre_flatten_strides. --- include/tvm/tir/buffer.h | 18 +++++++++++++++++- src/tir/ir/buffer.cc | 4 ++++ src/tir/transforms/arg_binder.cc | 20 +++++++++++++------- src/tir/transforms/storage_flatten.cc | 10 ++++++++-- 4 files changed, 42 insertions(+), 10 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index b06698662e26..fdd9e3fd31a7 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -55,13 +55,26 @@ class BufferNode : public Object { Var data; /*! \brief data type in the content of the tensor */ DataType dtype; - /*! \brief The shape of the buffer + /*! \brief The type of the buffer prior to flattening * * This contains the shape as it is accessed by * BufferLoad/BufferStore nodes, and used by the low-level code * generators. */ Array shape; + /*! \brief The shape of the buffer prior to flattening + * + * This contains the shape as it exists prior to flattening, and is + * used for validating the shape of the tensor passed into the + * packed API. + * + * TODO(Lunderberg): Should this be a reference to the entire + * pre-flattened Buffer instead of just the shape? That would also + * allow the PackedFunc to know how ArgBinder::BindDLTensor (called + * from MakePackedAPI) to know how the tensor should be flattened as + * it is being transferred from the device. + */ + DataType pre_flattened_dtype; /*! \brief The shape of the buffer prior to flattening * * This contains the shape as it exists prior to flattening, and is @@ -119,6 +132,7 @@ class BufferNode : public Object { v->Visit("data", &data); v->Visit("dtype", &dtype); v->Visit("shape", &shape); + v->Visit("pre_flattened_type", &pre_flattened_dtype); v->Visit("pre_flattened_shape", &pre_flattened_shape); v->Visit("pre_flattened_strides", &pre_flattened_strides); v->Visit("strides", &strides); @@ -135,6 +149,7 @@ class BufferNode : public Object { // in its semantics, skip name as name is not important. return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && equal.DefEqual(shape, other->shape) && + equal(pre_flattened_dtype, other->pre_flattened_dtype) && equal.DefEqual(pre_flattened_shape, other->pre_flattened_shape) && equal.DefEqual(pre_flattened_strides, other->pre_flattened_strides) && equal.DefEqual(strides, other->strides) && @@ -146,6 +161,7 @@ class BufferNode : public Object { hash_reduce.DefHash(data); hash_reduce(dtype); hash_reduce.DefHash(shape); + hash_reduce(pre_flattened_dtype); hash_reduce.DefHash(pre_flattened_shape); hash_reduce.DefHash(pre_flattened_strides); hash_reduce.DefHash(strides); diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index e0f4466e9007..daa6174772af 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -327,9 +327,12 @@ Buffer Buffer::GetFlattenedBuffer() const { // application of the pass. auto pre_flattened_shape = (*this)->pre_flattened_shape.value_or(self->shape); auto pre_flattened_strides = (*this)->pre_flattened_strides.value_or(self->strides); + auto pre_flattened_dtype = + (*this)->pre_flattened_dtype == DataType::Void() ? self->dtype : (*this)->pre_flattened_dtype; Buffer output = *this; auto writer = output.CopyOnWrite(); + writer->pre_flattened_dtype = pre_flattened_dtype; writer->pre_flattened_shape = pre_flattened_shape; writer->pre_flattened_strides = pre_flattened_strides; writer->shape = output_shape; @@ -474,6 +477,7 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array auto n = make_object(); n->data = std::move(data); n->dtype = dtype; + n->pre_flattened_dtype = DataType::Void(); n->shape = std::move(shape); n->strides = std::move(strides); diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index eac6ab804495..c2fa721b8849 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -164,6 +164,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, << "Please run StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules) first."; auto pre_flattened_strides = buffer->pre_flattened_strides.value(); + ICHECK_NE(buffer->pre_flattened_dtype, DataType::Void()) + << "Cannot bind tensor argument to an unflattened buffer. " + << "Please run StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules) first."; + DataType pre_flattened_dtype = buffer->pre_flattened_dtype; + // Helper functions for shape/stride name formatting auto shape_handle_name = [&]() { return arg_name + ".shape"; }; auto stride_handle_name = [&]() { return arg_name + ".strides"; }; @@ -181,16 +186,16 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, auto msg = tvm::tir::StringImm(ndim_err_msg.str()); asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); // type checks - DataType dtype = buffer->dtype; std::ostringstream type_err_msg; - type_err_msg << arg_name << ".dtype is expected to be " << dtype; + type_err_msg << arg_name << ".dtype is expected to be " << pre_flattened_dtype; PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) == - IntImm(DataType::UInt(8), dtype.code()) && + IntImm(DataType::UInt(8), pre_flattened_dtype.code()) && TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) == - IntImm(DataType::UInt(8), dtype.bits()) && + IntImm(DataType::UInt(8), pre_flattened_dtype.bits()) && TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) == - IntImm(DataType::UInt(16), dtype.lanes())); - if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) { + IntImm(DataType::UInt(16), pre_flattened_dtype.lanes())); + if (!(pre_flattened_dtype == DataType::Int(4) || pre_flattened_dtype == DataType::UInt(4) || + pre_flattened_dtype == DataType::Int(1))) { auto type_msg = tvm::tir::StringImm(type_err_msg.str()); asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); @@ -213,7 +218,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, init_nest_.emplace_back( LetStmt(buf_shape->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); for (size_t k = 0; k < pre_flattened_shape.size(); ++k) { - if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { + if (pre_flattened_dtype == DataType::Int(4) || pre_flattened_dtype == DataType::UInt(4) || + pre_flattened_dtype == DataType::Int(1)) { break; } Bind_( diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 141f5817dfcd..6f8c07a89a3f 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1097,8 +1097,14 @@ class StorageFlattener : public StmtExprMutator { // Boolean tensors are backed by a Int8 array. if (e.buffer->dtype == DataType::Bool()) { - auto writer = e.buffer.CopyOnWrite(); - writer->dtype = DataType::Int(8); + { + auto writer = e.buffer.CopyOnWrite(); + writer->dtype = DataType::Int(8); + } + { + auto writer = e.flattened_buffer.CopyOnWrite(); + writer->dtype = DataType::Int(8); + } } e.external = true; buf_map_[kv.second] = e; From 89081de1da7dbe52e959be03acc7f3f77ae9b47d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 10 Dec 2021 18:04:28 -0600 Subject: [PATCH 021/177] [UnitTest] Added tests for loop iteration order. After transformation, the iteration order should follow the new transformed axes. In addition, the loop iteration variables should be exposed through the TE interface for further manipulation. --- .../python/unittest/test_transform_layout.py | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py index 66d03dd67eda..fb85463f1cb8 100755 --- a/tests/python/unittest/test_transform_layout.py +++ b/tests/python/unittest/test_transform_layout.py @@ -279,5 +279,142 @@ def callback(node): post_order_visit(mod["main"].body, callback) +class TestTransformedSchedules: + logical_shape = tvm.testing.parameter((4, 6, 40)) + + transform_names = [ + None, + "reverse", + "flatten_all", + "factor_last_by_4", + ] + + transform_A = tvm.testing.parameter(by_dict={f"A_{t}": t for t in transform_names}) + transform_B = tvm.testing.parameter( + by_dict={f"B_{t}": t for t in transform_names if t is not None} + ) + + after_transform = tvm.testing.parameter(None) + + def make_transform(self, logical_shape, transform_name): + if transform_name is None: + return lambda *indices: indices + elif transform_name == "reverse": + return lambda *indices: indices[::-1] + elif transform_name == "flatten_all": + return flatten_all_indices(logical_shape) + elif transform_name == "factor_last_by_4": + return lambda *indices, n: [*indices, n // 4, n % 4] + else: + raise NotImplementedError(f"Unknown transformation {transform_name}") + + def make_transformed_shape(self, logical_shape, transform_name): + if transform_name is None: + return logical_shape + elif transform_name == "reverse": + return logical_shape[::-1] + elif transform_name == "flatten_all": + num_elements = functools.reduce(lambda x, y: x * y, logical_shape, 1) + return [num_elements] + elif transform_name == "factor_last_by_4": + *indices, n = logical_shape + return [*indices, n // 4, 4] + else: + raise NotImplementedError(f"Unknown transformation {transform_name}") + + @tvm.testing.fixture + def expected_loop_order(self, logical_shape, transform_B, after_transform): + shape = self.make_transformed_shape(logical_shape, transform_B) + + if after_transform == "reorder": + shape = shape[::-1] + + elif after_transform == "split": + shape = [ + *shape[:-1], + 2, + shape[-1] // 2, + ] + + elif after_transform == "fuse": + fused_size = shape[0] if transform_B == "flatten_all" else shape[0] * shape[1] + shape = [fused_size, *shape[2:]] + + return shape + + @tvm.testing.fixture + def schedule(self, logical_shape, dtype, transform_A, transform_B, after_transform): + A = te.placeholder(shape=logical_shape, dtype=dtype, name="A") + B = te.compute(shape=A.shape, fcompute=lambda i, j, k: A[i, j, k], name="B") + + s = te.create_schedule(B.op) + + if transform_A: + s[A].transform_layout(self.make_transform(logical_shape, transform_A)) + + iter_vars = s[B].transform_layout(self.make_transform(logical_shape, transform_B)) + iter_vars = list(iter_vars) + + if after_transform == "reorder": + s[B].reorder(*iter_vars[::-1]) + + elif after_transform == "split": + s[B].split(iter_vars[-1], nparts=2) + + elif after_transform == "fuse": + to_fuse = iter_vars[:2] + s[B].fuse(*iter_vars[:2]) + + return { + "schedule": s, + "tensors": [A, B], + "iter_vars": iter_vars, + } + + def compare_tir_loop_order(self, stmt, expected_loop_order): + def collect_loops(node): + output = [] + + def callback(node): + if isinstance(node, tvm.tir.For): + output.append(node) + + post_order_visit(node, callback) + return output[::-1] + + loops = collect_loops(stmt) + loop_order = [loop.extent for loop in loops] + + np.testing.assert_array_equal(loop_order, expected_loop_order) + + def test_tir_loop_order(self, schedule, expected_loop_order): + func = tvm.lower(schedule["schedule"], schedule["tensors"])["main"] + self.compare_tir_loop_order(func.body, expected_loop_order) + + def test_te_loop_order(self, schedule, expected_loop_order): + s = schedule["schedule"] + A, B = schedule["tensors"] + iter_vars = schedule["iter_vars"] + + # No reduction axis, so all leaf_iter_vars are over the data + # array, and should have the new iteration variables. + extents = [int(iter_var.dom.extent) for iter_var in s[B].leaf_iter_vars] + np.testing.assert_array_equal(extents, expected_loop_order) + + # layout_transform should return the new iteration variables. + extents = [int(iter_var.dom.extent) for iter_var in iter_vars] + np.testing.assert_array_equal(extents, expected_loop_order) + + @pytest.mark.parametrize("after_transform", ["reorder", "split", "fuse"]) + def test_use_transformed_axes( + self, schedule, expected_loop_order, transform_A, transform_B, after_transform + ): + s = schedule["schedule"] + A, B = schedule["tensors"] + + func = tvm.lower(s, [A, B])["main"] + self.compare_tir_loop_order(func.body, expected_loop_order) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) From 699bb179d4bdecc909eba4782e0fef281fae828f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 10 Dec 2021 21:46:04 -0600 Subject: [PATCH 022/177] [TIR] Added BufferNode::axis_separators - Add axis_separators to represent divisions between groups of tensor axes, where each group is flattened into a single output axis, to be exposed to the low-level code generators. - Expose axis_separators to the python interface. - Update existing C++ calls to the Buffer() constructor. --- include/tvm/tir/buffer.h | 18 ++++- python/tvm/tir/buffer.py | 11 +++ python/tvm/tir/ir_builder.py | 13 ++- src/printer/tir_text_printer.cc | 3 + src/tir/analysis/device_constraint_utils.cc | 5 +- src/tir/ir/buffer.cc | 87 +++++++++++++++++---- 6 files changed, 115 insertions(+), 22 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index fdd9e3fd31a7..2840f98da8e7 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -101,6 +101,15 @@ class BufferNode : public Object { * flattened as it is being transferred from the device. */ Optional> pre_flattened_strides; + /*! + * \brief Separators between input axes when generating flattened output axes + * + * For buffers representing flat 1-d memory (e.g. any buffer in + * RAM), this should be an empty array. For buffers representing + * non-flat memory, each entry in axis_separators should be the + * first input axis that is part of a new flattened axis. + */ + Array axis_separators; /*! * \brief The strides of each dimension * This can be an empty array, indicating array is contiguous @@ -136,6 +145,7 @@ class BufferNode : public Object { v->Visit("pre_flattened_shape", &pre_flattened_shape); v->Visit("pre_flattened_strides", &pre_flattened_strides); v->Visit("strides", &strides); + v->Visit("axis_separators", &axis_separators); v->Visit("elem_offset", &elem_offset); v->Visit("name", &name); v->Visit("data_alignment", &data_alignment); @@ -153,6 +163,7 @@ class BufferNode : public Object { equal.DefEqual(pre_flattened_shape, other->pre_flattened_shape) && equal.DefEqual(pre_flattened_strides, other->pre_flattened_strides) && equal.DefEqual(strides, other->strides) && + equal.DefEqual(axis_separators, other->axis_separators) && equal.DefEqual(elem_offset, other->elem_offset) && equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); } @@ -166,6 +177,7 @@ class BufferNode : public Object { hash_reduce.DefHash(pre_flattened_strides); hash_reduce.DefHash(strides); hash_reduce.DefHash(elem_offset); + hash_reduce.DefHash(axis_separators); hash_reduce(data_alignment); hash_reduce(buffer_type); } @@ -200,7 +212,7 @@ class Buffer : public ObjectRef { // A default value will be picked. TVM_DLL Buffer(Var data, DataType dtype, Array shape, Array strides, PrimExpr elem_offset, String name, int data_alignment, int offset_factor, - BufferType buffer_type, Span span = Span()); + BufferType buffer_type, Array axis_separators = {}, Span span = Span()); /*! * \brief Return a new buffer that is equivalent with current one @@ -260,12 +272,14 @@ class Buffer : public ObjectRef { * \param dtype The content data type. * \param name The name of the buffer * \param storage_scope The storage scope associated with this buffer + * \param axis_separators Divisions defining the groups of axes that will be flattened together. * \param span The location of this object in the source code. * \return The created buffer. * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", String storage_scope = "", Span span = Span()); + String name = "buffer", String storage_scope = "", + Array axis_separators = {}, Span span = Span()); /*! * \brief Base node for data producers. diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 6dddd7b119a0..12947bab49a4 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -155,6 +155,7 @@ def decl_buffer( data_alignment=-1, offset_factor=0, buffer_type="", + axis_separators=None, span=None, ): """Declare a new symbolic buffer. @@ -204,6 +205,11 @@ def decl_buffer( without considering whether dimension size equals to one. TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1. + axis_separators : list of int, optional + If passed, a list of separators between groups of axes, + each of which is flattened to an output axis. For flat + memory spaces, should either be None, or an empty list. + span: Optional[Span] The location of the decl_buffer creation in the source. @@ -254,6 +260,10 @@ def decl_buffer( shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape dtype = "float32" if dtype is None else dtype strides = () if strides is None else strides + + if axis_separators is None: + axis_separators = [] + if offset_factor != 0 and elem_offset is None: shape_dtype = shape[0].dtype if shape and hasattr(shape[0], "dtype") else "int32" elem_offset = Var("%s_elem_offset" % name, shape_dtype) @@ -272,6 +282,7 @@ def decl_buffer( data_alignment, offset_factor, buffer_type, + axis_separators, span, ) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 52c54f4bf720..f95b6b73fadb 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -388,7 +388,7 @@ def let(self, var_name, value): self.emit(lambda x: _stmt.LetStmt(var, value, x)) return var - def allocate(self, dtype, shape, name="buf", scope=""): + def allocate(self, dtype, shape, name="buf", axis_separators=None, scope=""): """Create a allocate statement. Parameters @@ -402,6 +402,12 @@ def allocate(self, dtype, shape, name="buf", scope=""): name : str, optional The name of the buffer. + axis_separators : list of int, optional + + If passed, a list of separators between groups of axes, + each of which is flattened to an output axis. For flat + memory spaces, should either be None, or an empty list. + scope : str, optional The scope of the buffer. @@ -410,11 +416,14 @@ def allocate(self, dtype, shape, name="buf", scope=""): ------- buffer : BufferVar The buffer var representing the buffer. + """ if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] - buffer = _buffer.decl_buffer(shape, dtype, name, scope=scope) + buffer = _buffer.decl_buffer( + shape, dtype, name, scope=scope, axis_separators=axis_separators + ) buffer_var = buffer.data self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index a9804229da91..7f79168a14f4 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -223,6 +223,9 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { if (!is_zero(buf->elem_offset)) { doc << ", elem_offset=" << Print(buf->elem_offset); } + if (buf->axis_separators.size()) { + doc << ", axis_separators=" << Print(buf->axis_separators); + } if (GetRef(buf).scope() != "global") { doc << ", scope=" << Doc::StrLiteral(GetRef(buf).scope()); } diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 26cf66c4d4c0..9a1e5ba38cad 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -425,9 +425,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { PointerType new_pointer_type(pointer_type_node->element_type, virtual_device->memory_scope); Var new_data(buffer->data->name_hint, new_pointer_type, buffer->data->span); var_subst_.emplace(buffer->data.get(), new_data); - Buffer new_buffer(new_data, buffer->dtype, buffer->shape, buffer->strides, buffer->elem_offset, - buffer->name, buffer->data_alignment, buffer->offset_factor, - buffer->buffer_type, buffer->span); + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->data = new_data; buffer_subst_.emplace(buffer.get(), new_buffer); return new_buffer; } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index daa6174772af..4f0d7de4e202 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -48,10 +48,10 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { } Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, - Span span) { + Array axis_separators, Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, - Array(), PrimExpr(), name, 0, 0, kDefault, span); + Array(), PrimExpr(), name, 0, 0, kDefault, axis_separators, span); } // Split the given expression w.r.t the add operator @@ -256,11 +256,33 @@ Array BufferNode::ElemOffset(Array input_indices) const { << "the index's dimensionality must match the dimensionality of the index given."; } - PrimExpr output_index = 0; + // TODO(Lunderberg): Better handling for cases where there is more + // than one output index. Currently, this only allows elem_offset + // to be non-zero for flat memory allocations. + Array elem_offsets = {}; + if (elem_offset.defined() && !is_zero(elem_offset)) { + elem_offsets = {elem_offset}; + } + + if (elem_offsets.size()) { + ICHECK_EQ(elem_offsets.size(), axis_separators.size() + 1) + << "If element offsets are defined, " + << "there must be one element offset for each output index."; + } + + Array output_indices(axis_separators.size() + 1, 0); + + size_t current_output_axis = 0; arith::Analyzer ana; for (size_t i = 0; i < input_indices.size(); i++) { + if ((current_output_axis < axis_separators.size()) && + (i == size_t(axis_separators[current_output_axis]->value))) { + current_output_axis++; + } + + PrimExpr output_index = output_indices[current_output_axis]; if (strides.size()) { output_index = output_index + input_indices[i] * strides[i]; } else { @@ -270,13 +292,17 @@ Array BufferNode::ElemOffset(Array input_indices) const { if (i > 0) { output_index = MergeMulMod(&ana, output_index); } + + output_indices.Set(current_output_axis, output_index); } - if (elem_offset.defined() && !is_zero(elem_offset)) { - output_index = output_index + elem_offset; + if (elem_offsets.size()) { + for (size_t i = 0; i < output_indices.size(); i++) { + output_indices.Set(i, output_indices[i] + elem_offsets[i]); + } } - return {output_index}; + return output_indices; } inline Array BufferOffset(const BufferNode* n, Array index, DataType dtype) { @@ -302,26 +328,56 @@ inline Array BufferOffset(const BufferNode* n, Array index, Buffer Buffer::GetFlattenedBuffer() const { auto self = operator->(); - PrimExpr output_size; + // These checks ensure that all output axes contain at least one + // input axis. + for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) { + auto sep = self->axis_separators[i]->value; + auto next_sep = self->axis_separators[i]->value; + ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order."; + } + if (self->axis_separators.size()) { + auto first_sep = self->axis_separators[0]->value; + ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, " + << "so that first output axis contains at least one input axis"; + auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value; + ICHECK_LT(last_sep, self->shape.size()) + << "Last output axis must contain at least one input axis."; + } + + Array output_shape; if (self->strides.size()) { // If strides are defined, then the extent of each flattened // buffer is the stride*size for the first input axis used for // each output axis. ICHECK_EQ(self->shape.size(), self->strides.size()); - output_size = self->strides[0] * self->shape[0]; + output_shape.push_back(self->strides[0] * self->shape[0]); + for (const auto& sep : self->axis_separators) { + output_shape.push_back(self->strides[sep->value] * self->shape[sep->value]); + } } else { // Otherwise, the extent of each flattened buffer is the product // of the extents of each input axis used to generate that output // axis. This also "flattens" rank-0 tensors to a rank-1 buffer // of shape [1]. - - output_size = 1; + output_shape = Array(self->axis_separators.size() + 1, 1); + size_t current_output_index = 0; for (size_t i = 0; i < self->shape.size(); i++) { - output_size = output_size * self->shape[i]; + if ((current_output_index < self->axis_separators.size()) && + (i == size_t(self->axis_separators[current_output_index]->value))) { + current_output_index += 1; + } + output_shape.Set(current_output_index, output_shape[current_output_index] * self->shape[i]); } } + // The axis_separators for the output buffer. + Array output_axis_separators; + for (size_t i = 0; i < self->axis_separators.size(); i++) { + auto dtype = self->axis_separators[i]->dtype; + output_axis_separators.push_back(IntImm(dtype, i + 1)); + } + // If a flattening pass is called multiple times, then the // pre-flattened shape/strides should be from before the first // application of the pass. @@ -464,7 +520,7 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, PrimExpr elem_offset, String name, int data_alignment, int offset_factor, - BufferType buffer_type, Span span) { + BufferType buffer_type, Array axis_separators, Span span) { DataType storage_dtype = dtype; // specially handle bool if (storage_dtype == DataType::Bool()) { @@ -481,6 +537,7 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array n->shape = std::move(shape); n->strides = std::move(strides); + n->axis_separators = std::move(axis_separators); n->name = std::move(name); if (!elem_offset.defined()) { elem_offset = make_const(n->DefaultIndexType(), 0); @@ -513,11 +570,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(BufferNode); TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args.size(), 10); + ICHECK_EQ(args.size(), 11); auto buffer_type = args[8].operator String(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; - *ret = - Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type, args[9]); + *ret = Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type, + args[9], args[10]); }); TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); From c3ff6f6852039febbaf3988f9b5951692bacd044 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 12 Oct 2021 13:53:26 -0500 Subject: [PATCH 023/177] [TIR] Added ApplyLayoutTransforms as part of StorageFlatten. For any buffers that have layout transforms defined in the "layout_transform_map" attribute of a PrimFunc, rewrite access into the buffer such that they use the updated ordering. --- include/tvm/ir/attrs.h | 41 ++++++++++ include/tvm/tir/stmt.h | 2 + src/tir/transforms/storage_flatten.cc | 112 ++++++++++++++++++++++++++ 3 files changed, 155 insertions(+) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index f6c15f9590df..9a2468714962 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -382,6 +382,47 @@ inline TFunc WithAttrs(TFunc input, Map attrs) { return input; } +/*! + * \brief Copy the function or module, but removes the specified + * attribute. + * + * \param input The thing to annotate (BaseFunc or IRModule) + * \param attr_key The attribute key. + * + * \tparam TFunc The corresponding function or module type. + * + * \returns The new function or module with removed attribute. + * + * \note This function performs copy on write optimization for func and module. + * If we move a uniquely referenced func or module into WithoutAttr, + * then no additional copy will be performed. + * + * This is also why we make it as a function instead of a member function + * and why we pass by value in the first argument. + * + * \code + * + * // Recommended way to trigger copy on write + * func = WithoutAttr(std::move(func), "key1"); + * func = WithoutAttr(std::move(func), "key2"); + * + * \endcode + */ +template +inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { + using TNode = typename TFunc::ContainerType; + static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + + if (input->attrs.defined()) { + TNode* node = input.CopyOnWrite(); + node->attrs.CopyOnWrite()->dict.erase(attr_key); + if (node->attrs->dict.size() == 0) { + node->attrs = NullValue(); + } + } + return input; +} + // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 37d6193697f1..bd7528439655 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -388,6 +388,7 @@ class BufferRealize : public Stmt { Span span = Span()); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode); }; /*! @@ -583,6 +584,7 @@ class Allocate : public Stmt { Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode); }; /*! diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 6f8c07a89a3f..e66eb6b4054d 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1066,6 +1066,117 @@ class BufferBindUnwrapper : public StmtExprMutator { IRVisitorWithAnalyzer* bound_analyzer_; }; +class ApplyLayoutTransforms : public StmtExprMutator { + public: + static transform::Pass Pass() { + auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { + auto lookup = func->attrs.GetAttr>>("layout_transform_map"); + + if (!lookup) { + return func; + } + + Map> layout_transforms = lookup.value(); + + auto fptr = func.CopyOnWrite(); + + auto mutator = ApplyLayoutTransforms(layout_transforms); + fptr->buffer_map = mutator.UpdateExternBufferMap(fptr->buffer_map); + fptr->body = mutator(std::move(fptr->body)); + + return WithoutAttr(std::move(func), "layout_transform_map"); + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.ApplyLayoutTransforms", {}); + } + + explicit ApplyLayoutTransforms(Map> layout_transforms) + : layout_transforms_(layout_transforms) {} + + Map UpdateExternBufferMap(const Map& buffer_map) { + Map output; + for (const auto& kv : buffer_map) { + output.Set(kv.first, GetBufferRemap(kv.second, true)); + } + return output; + } + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + // Call once so that load/store nodes can read from the cached + // value. + GetBufferRemap(op->buffer, true); + + auto realize = Downcast(StmtExprMutator::VisitStmt_(op)); + + auto lookup = layout_transforms_.Get(op->buffer); + if (lookup) { + auto write_ptr = realize.CopyOnWrite(); + write_ptr->buffer = GetBufferRemap(op->buffer, true); + + Array transforms = lookup.value(); + for (const auto& transform : transforms) { + write_ptr->bounds = transform->MapRanges(realize->bounds); + } + } + + return std::move(realize); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + auto lookup = layout_transforms_.Get(node->buffer); + if (lookup) { + auto write_ptr = node.CopyOnWrite(); + + write_ptr->buffer = GetBufferRemap(node->buffer); + + Array transforms = lookup.value(); + for (const auto& transform : transforms) { + write_ptr->indices = transform->MapIndices(node->indices); + } + } + return node; + } + + private: + //! \brief Given a buffer, return the buffer it should be remapped into. + Buffer GetBufferRemap(Buffer buf, bool allow_alloc = false) { + auto key = buf.get(); + auto it = buf_map_.find(key); + if (it != buf_map_.end()) { + return it->second; + } + + ICHECK(allow_alloc) << "Buffer " << buf << " accessed before declaration."; + + auto lookup = layout_transforms_.Get(buf); + if (lookup) { + Array transforms = lookup.value(); + + auto write_ptr = buf.CopyOnWrite(); + for (const auto& transform : transforms) { + write_ptr->shape = transform->MapShape(buf->shape); + } + } + + buf_map_[key] = buf; + return buf; + } + + std::unordered_map buf_map_; + + Map> layout_transforms_; +}; + class StorageFlattener : public StmtExprMutator { public: static transform::Pass Pass(int cache_line_size, bool create_bound_attributes) { @@ -1522,6 +1633,7 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at BufferStrideLegalize::Pass(), ThreadScopePropagate::Pass(), BufferBindUnwrapper::Pass(), + ApplyLayoutTransforms::Pass(), StorageFlattener::Pass(cache_line_size, create_bound_attributes), AssertSimplifier::Pass(), }, From acad83f125e776708041c72a6bef38dd9176b891 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 13 Dec 2021 10:22:21 -0600 Subject: [PATCH 024/177] Update usage of ir_builder where necessary. --- python/tvm/topi/cuda/sparse.py | 12 ++++++------ tests/python/unittest/test_target_codegen_vulkan.py | 5 ++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index 32f20a15016e..8bfc8032bfef 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -149,7 +149,6 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) m = data.shape[1] nb = w_indptr.shape[0] - 1 - nnzb = w_data.shape[0] # treat csr like block size 1 bsr if len(w_data.shape) == 1: bs_n = 1 @@ -181,7 +180,7 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - w_data_ptr = ib.buffer_ptr(w_data, shape=(nnzb, bs_n, bs_k)) + w_data_ptr = ib.buffer_ptr(w_data) w_indices_ptr = ib.buffer_ptr(w_indices) w_indptr_ptr = ib.buffer_ptr(w_indptr) @@ -238,10 +237,11 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): elem_idx = bb * rowlength_bi + tx with ib.for_range(0, bs_n, name="y", kind="unroll") as y: with ib.for_range(0, bs_k, name="z", kind="unroll") as z: - if use_warp_storage: - w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z] - else: - w_data_cache[warp, tx, y, z] = w_data_ptr[row_start + elem_idx, y, z] + data_indices = [row_start + elem_idx] + ( + [y, z] if len(w_data.shape) > 1 else [] + ) + cache_indices = [tx, y, z] if use_warp_storage else [warp, tx, y, z] + w_data_cache[cache_indices] = w_data_ptr[data_indices] with ib.for_range(0, mi, name="i") as i: # thread local block matmul with ib.for_range(0, bs_m, name="x", kind="unroll") as x: diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 7b708cbe0c12..bde1ca4d0a58 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -502,10 +502,9 @@ def do_compute(ins, outs): store_index = index_map[store_type] if indirect_indices: - load_index = tvm.tir.expr.Load("int32x4", R, load_index) + load_index = R[load_index] - transfer = tvm.tir.expr.Load("int32x4", A, load_index) - ib.emit(tvm.tir.stmt.Store(B, transfer, store_index)) + B[store_index] = A[load_index] return ib.get() From a62f449e34255ec5a294a39a87f0fbe42f097dd8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 6 Dec 2021 12:50:00 -0600 Subject: [PATCH 025/177] [TE] Implement te::Transform Similar to Fuse and Split, this represents a modification to the existing loop iterations. --- include/tvm/te/schedule.h | 30 +++++++ src/te/schedule/message_passing.cc | 130 +++++++++++++++++++++++++++++ src/te/schedule/schedule_lang.cc | 10 +++ 3 files changed, 170 insertions(+) diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index 24e04fc685d1..e7c9df3892e8 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -813,6 +813,36 @@ class Singleton : public IterVarRelation { TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode); }; +/*! + * \brief Transform iterator according to some arbitrary expression. + */ +class TransformNode : public IterVarRelationNode { + public: + Array original_variables; + Array transformed_variables; + IndexMap forward_transformation; + IndexMap inverse_transformation; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("original_variables", &original_variables); + v->Visit("transformed_variables", &transformed_variables); + v->Visit("forward_transformation", &forward_transformation); + v->Visit("inverse_transformation", &inverse_transformation); + } + + static constexpr const char* _type_key = "Transform"; + TVM_DECLARE_FINAL_OBJECT_INFO(TransformNode, IterVarRelationNode); +}; + +class Transform : public IterVarRelation { + public: + TVM_DLL explicit Transform(Array original_variables, + Array transformed_variables, IndexMap forward_transformation, + IndexMap inverse_transformation); + + TVM_DEFINE_OBJECT_REF_METHODS(Transform, IterVarRelation, TransformNode); +}; + /*! \brief Container for specialization conditions. */ class SpecializedConditionNode : public Object { public: diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index d45f29ebc5b6..b1056ac2447d 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -79,6 +79,22 @@ void PassUpThreadBinding(const Stage& stage, std::unordered_map* } else if (const RebaseNode* s = rel.as()) { state[s->parent] = state[s->rebased]; } else if (rel.as()) { + } else if (const TransformNode* s = rel.as()) { + // Currently, this marks all original iter vars as deriving from + // a thread bind if any of the transformed variables are bound, + // even if the inverse expression for that iter var doesn't + // depend on the bound variable. + + // TODO(Lunderberg): For each of original variable, check + // whether any variable in the inverse expression for it has a + // thread binding. + bool is_thread_binding = false; + for (const auto& iter_var : s->transformed_variables) { + is_thread_binding = is_thread_binding || state[iter_var]; + } + for (const auto& iter_var : s->original_variables) { + state[iter_var] = is_thread_binding; + } } else { LOG(FATAL) << "unknown relation type"; } @@ -157,6 +173,29 @@ void PassDownDomain(const Stage& stage, std::unordered_map* p_st Update(p_state, r->rebased, Range::FromMinExtent(0, state.at(r->parent)->extent), actx); } else if (const SingletonNode* s = rel.as()) { Update(p_state, s->iter, Range::FromMinExtent(0, 1), actx); + } else if (const TransformNode* s = rel.as()) { + bool missing_originals = false; + for (const auto& iter_var : s->original_variables) { + if (!state.count(iter_var)) { + ICHECK(allow_missing); + missing_originals = true; + } + } + if (missing_originals) { + continue; + } + + Array original_ranges; + for (const auto& iter_var : s->original_variables) { + original_ranges.push_back(state[iter_var]); + } + Array updated_ranges = s->forward_transformation->MapRanges(original_ranges); + + ICHECK_EQ(updated_ranges.size(), s->transformed_variables.size()); + for (size_t i = 0; i < updated_ranges.size(); i++) { + Update(p_state, s->transformed_variables[i], updated_ranges[i], actx); + } + } else { LOG(FATAL) << "unknown relation type"; } @@ -225,6 +264,29 @@ void PassUpIndex(const Stage& stage, const Map& dom_map, state[s->parent] = value; } } else if (rel.as()) { + } else if (const TransformNode* s = rel.as()) { + bool missing_transformed = false; + for (const auto& iter_var : s->transformed_variables) { + if (!state.count(iter_var)) { + ICHECK(allow_missing); + missing_transformed = true; + } + } + if (missing_transformed) { + continue; + } + + Array transformed_indices; + for (const auto& iter_var : s->transformed_variables) { + transformed_indices.push_back(state[iter_var]); + } + Array original_indices = s->inverse_transformation->MapIndices(transformed_indices); + + ICHECK_EQ(original_indices.size(), s->original_variables.size()); + for (size_t i = 0; i < original_indices.size(); i++) { + state[s->original_variables[i]] = original_indices[i]; + } + } else { LOG(FATAL) << "unknown relation type"; } @@ -270,6 +332,28 @@ void PassDownIndex(const Stage& stage, const Map& dom_map, state[s->rebased] = value; } else if (const SingletonNode* s = rel.as()) { state[s->iter] = make_zero(s->iter->var.dtype()); + } else if (const TransformNode* s = rel.as()) { + bool missing_originals = false; + for (const auto& iter_var : s->original_variables) { + if (!state.count(iter_var)) { + ICHECK(allow_missing); + missing_originals = true; + } + } + if (missing_originals) { + continue; + } + + Array original_indices; + for (const auto& iter_var : s->original_variables) { + original_indices.push_back(state[iter_var]); + } + Array transformed_indices = s->forward_transformation->MapIndices(original_indices); + + ICHECK_EQ(transformed_indices.size(), s->transformed_variables.size()); + for (size_t i = 0; i < transformed_indices.size(); i++) { + state[s->transformed_variables[i]] = transformed_indices[i]; + } } else { LOG(FATAL) << "unknown relation type"; } @@ -351,6 +435,26 @@ void PassUpDomain(const RebaseNode* s, const std::unordered_map& *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}}); } +Array PassUpDomain(const TransformNode* s, + const std::unordered_map& dom_map, + const Map& transformed_domains) { + Array output; + + Array transformed_indices; + for (const auto& iter_var : s->transformed_variables) { + transformed_indices.push_back(iter_var->var); + } + + Array transformed_exprs = s->inverse_transformation->MapIndices(transformed_indices); + + ICHECK_EQ(transformed_exprs.size(), s->original_variables.size()); + for (size_t i = 0; i < transformed_exprs.size(); i++) { + output.push_back(arith::EvalSet(transformed_exprs[i], transformed_domains)); + } + + return output; +} + void PassUpDomain(const Stage& stage, const std::unordered_map& dom_map, std::unordered_map* p_state) { auto& state = *p_state; @@ -370,6 +474,16 @@ void PassUpDomain(const Stage& stage, const std::unordered_map& PassUpDomain(r, dom_map, state.at(r->rebased), &parent); state[r->parent] = parent; } else if (rel.as()) { + } else if (const TransformNode* r = rel.as()) { + Map transformed_domains; + for (const auto& var : r->transformed_variables) { + transformed_domains.Set(var, state.at(var)); + } + auto original_ranges = PassUpDomain(r, dom_map, transformed_domains); + ICHECK_EQ(original_ranges.size(), r->original_variables.size()); + for (size_t i = 0; i < original_ranges.size(); i++) { + state[r->original_variables[i]] = original_ranges[i]; + } } else { LOG(FATAL) << "unknown relation type"; } @@ -509,6 +623,22 @@ void PassUpBoundCheck(const Stage& s, const Map& dom_map, state[s->parent] = state.at(s->rebased); } else if (rel.as()) { // nop + } else if (const TransformNode* s = rel.as()) { + // Currently, this marks all original iter vars as requiring + // bounds checks if any of the transformed variables require + // bounds checks, even if the inverse expression for that iter + // var doesn't depend on the bound variable. + + // TODO(Lunderberg): For each of original variable, check + // whether any variable in the inverse expression for it + // requires bounds checking. + bool needs_bounds_check = false; + for (const auto& iter_var : s->transformed_variables) { + needs_bounds_check = needs_bounds_check || state[iter_var]; + } + for (const auto& iter_var : s->original_variables) { + state[iter_var] = needs_bounds_check; + } } else { LOG(FATAL) << "unknown relation type"; } diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index da61355a9e1d..beb13c66f7a2 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -723,6 +723,16 @@ Singleton::Singleton(IterVar iter) { data_ = std::move(n); } +Transform::Transform(Array original_variables, Array transformed_variables, + IndexMap forward_transformation, IndexMap inverse_transformation) { + auto n = make_object(); + n->original_variables = original_variables; + n->transformed_variables = transformed_variables; + n->forward_transformation = forward_transformation; + n->inverse_transformation = inverse_transformation; + data_ = std::move(n); +} + SpecializedCondition::SpecializedCondition(Array conditions) { ObjectPtr n = make_object(); n->clauses = std::move(conditions); From 25ff74c6e664a96288673b06ade1a3e2b0edb61d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 13 Oct 2021 11:28:28 -0500 Subject: [PATCH 026/177] [TE] Added Stage::set_axis_separators. In C++, this is implemented as an `Array`, specifying pre-flatteneing axes after which a new post-flattening should be started. The python interface uses a sentinel value `te.AXIS_SEPARATOR` in the call to `transform_layout`, which is then used to define the array of axis separators. --- include/tvm/te/schedule.h | 15 +++++++++++++++ python/tvm/te/__init__.py | 8 +++++++- python/tvm/te/schedule.py | 23 ++++++++++++++++++++++- src/te/schedule/schedule_lang.cc | 9 ++++++++- 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index eef01ad1b0f1..24e04fc685d1 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -280,6 +280,14 @@ class Stage : public ObjectRef { */ TVM_DLL Stage& transform_layout(const Array& initial_indices, const Array& final_indices); + /*! \brief Defines separators between groups of axes. + * + * Used to define `BufferNode::axis_separators`, which has + * additional details. + * + * \param axis_separators A list of axis separators. + */ + TVM_DLL Stage& set_axis_separators(const Array& axis_separators); /*! * \brief whether the stage has been scheduled. * \return whether the stage has been scheduled. @@ -526,6 +534,12 @@ class StageNode : public Object { bool rolling_buffer{false}; /*! \brief Layout transformations to be applied onto the stage's tensors. */ Array layout_transforms; + /*! \brief List of axes after which to divide physical axes. + * + * Used to populate `BufferNode::axis_separators`, which has + * additional details. + */ + Array axis_separators; /*! * \brief The parent group of the current stage. * The stage cannot be assigned to stages outside the group. @@ -549,6 +563,7 @@ class StageNode : public Object { v->Visit("is_output", &is_output); v->Visit("double_buffer", &double_buffer); v->Visit("layout_transforms", &layout_transforms); + v->Visit("axis_separators", &axis_separators); v->Visit("group", &group); v->Visit("num_child_stages", &num_child_stages); } diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 308257085e51..8b59cc4797bf 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -27,7 +27,13 @@ from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from tvm.tir import comm_reducer, min, max, sum -from .schedule import Schedule, Stage, create_schedule, SpecializedCondition +from .schedule import ( + Schedule, + Stage, + create_schedule, + SpecializedCondition, + AXIS_SEPARATOR, +) from .tensor import TensorSlice, Tensor from .tensor_intrin import decl_tensor_intrin from .tag import tag_scope diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index b8f0ab7438fa..7b57bd5c98f8 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -610,9 +610,24 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr f"but {self.op.name} is {len(self.op.shape)}-dimensional" ) - final_indices = mapping_function(*args, **kwargs) + mapping = mapping_function(*args, **kwargs) + + final_indices = [] + axis_separators = [] + for val in mapping: + if isinstance(val, tvm.ir.PrimExpr): + final_indices.append(val) + elif val is AXIS_SEPARATOR: + axis_separators.append(len(final_indices)) + else: + raise TypeError( + "Expected mapping function to return list of " + "either tvm.ir.PrimExpr or tvm.te.AXIS_SEPARATOR. " + "Instead received {val} of type {type(val)}." + ) _ffi_api.StageTransformLayout(self, initial_indices, final_indices) + _ffi_api.StageSetAxisSeparators(self, axis_separators) @@ -652,4 +667,10 @@ def __exit__(self, ptype, value, trace): _ffi_api.ExitSpecializationScope(self) +# Sentinel value used to indicate which groups of pre-flattening axes +# should be used to post-flattening axes axes. See +# Stage.transform_layout for more details. +AXIS_SEPARATOR = "axis_separator" + + tvm._ffi._init_api("schedule", __name__) diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index b7a73e2e3adf..da61355a9e1d 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -432,11 +432,16 @@ Stage& Stage::rolling_buffer() { Stage& Stage::transform_layout(const Array& initial_indices, const Array& final_indices) { StageNode* self = operator->(); - self->layout_transforms.push_back(IndexMap(initial_indices, final_indices)); return *this; } +Stage& Stage::set_axis_separators(const Array& axis_separators) { + StageNode* self = operator->(); + self->axis_separators = axis_separators; + return *this; +} + Stage CopyStage(const Stage& s) { ObjectPtr n = make_object(*s.operator->()); return Stage(n); @@ -904,6 +909,8 @@ TVM_REGISTER_GLOBAL("te.StageRollingBuffer").set_body_method(&Stage::rolling_buf TVM_REGISTER_GLOBAL("te.StageTransformLayout").set_body_method(&Stage::transform_layout); +TVM_REGISTER_GLOBAL("te.StageSetAxisSeparators").set_body_method(&Stage::set_axis_separators); + TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize); TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group); From 3735b6f0233d154aed4134f57b2f107ab433a397 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 12 Oct 2021 10:22:06 -0500 Subject: [PATCH 027/177] [TIR] Expose tir.transform.ApplyLayoutTransforms for testing --- python/tvm/tir/transform/transform.py | 13 +++++++++++++ src/tir/transforms/storage_flatten.cc | 3 +++ 2 files changed, 16 insertions(+) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 834335766551..8334777ddc01 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -74,6 +74,19 @@ def InjectPrefetch(): return _ffi_api.InjectPrefetch() # type: ignore +def ApplyLayoutTransforms(): + """Reshape buffers that appear in the "layout_transform_map" + fucntion attribute. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + + """ + return _ffi_api.ApplyLayoutTransforms() # type: ignore + + def StorageFlatten(cache_line_size, create_bound_attribute: bool = False): """Flatten the multi-dimensional read/write to 1D. diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index e66eb6b4054d..2988129fdca8 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1649,6 +1649,9 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at namespace transform { +TVM_REGISTER_GLOBAL("tir.transform.ApplyLayoutTransforms") + .set_body_typed(ApplyLayoutTransforms::Pass); + // TODO(tvm-team): consolidate configs to the PassContext Pass StorageFlatten(int cache_line_size, bool create_bound_attributes) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { From ea4c10a1b16befefe2f1e9b3b6376a83beb1a5b1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 6 Dec 2021 12:49:09 -0600 Subject: [PATCH 028/177] [TE] Rewrite loop iteration order After .transform_layout, rewrite leaf_iter_vars to follow the updated order. Use the te::Transform iter_var relationship to track use of the transformed variable. --- include/tvm/te/operation.h | 1 + include/tvm/te/schedule.h | 22 ++++++++++-- src/te/schedule/schedule_lang.cc | 60 +++++++++++++++++++++++++++++++- 3 files changed, 80 insertions(+), 3 deletions(-) diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 13f39317dbe4..e91a0930f37b 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -265,6 +265,7 @@ class ComputeOp : public Operation { Array axis, Array body); TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode); }; /*! diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index e7c9df3892e8..fc05a8bd2245 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -498,9 +498,27 @@ class StageNode : public Object { * while origin_op remains fixed. */ Operation origin_op; - /*! \brief All the nodes in the iter var */ + /*! \brief All the nodes in the iter var + * + * Each element of all_iter_vars represents an iteration variable + * that may appear within this stage's computation. Any element + * of `all_iter_vars` that is in `leaf_iter_vars` represents a + * variable that is directly defined and usable within the stage's + * computation. All other elements of `all_iter_vars` represent + * variables whose value must be computed from the variables in + * `leaf_iter_vars`. (e.g. Support index k has been split by + * ``ko, ki = s.split(k, factor=4)``. ko and ki will appear in + * `leaf_iter_vars`, while k will not, and must be computed as + * `4*ko + ki`. + */ Array all_iter_vars; - /*! \brief The current active leaf iter vars in the stage. */ + /*! \brief The current active leaf iter vars in the stage. + * + * Each element of leaf_iter_vars will either be replaced with the + * bound index (e.g. threadIdx.x), or will be expanded into a loop + * over the variable's extent. `leaf_iter_vars` is a subset of + * `all_iter_vars`. + */ Array leaf_iter_vars; /*! * \brief Specify threads to be launched at the stage. diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index beb13c66f7a2..f5af49418b7a 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -25,8 +25,10 @@ #include #include +#include #include #include +#include #include "graph.h" @@ -432,7 +434,63 @@ Stage& Stage::rolling_buffer() { Stage& Stage::transform_layout(const Array& initial_indices, const Array& final_indices) { StageNode* self = operator->(); - self->layout_transforms.push_back(IndexMap(initial_indices, final_indices)); + IndexMap map(initial_indices, final_indices); + self->layout_transforms.push_back(map); + + auto* compute = self->op.as(); + + // Can only rewrite the indices of compute op nodes. + if (!compute) { + return *this; + } + + CHECK_EQ(initial_indices.size(), compute->axis.size()) + << "Expected number of initial indices in transformation to match the dimension of " + << self->op->name; + + // Locate the IterVar objects for the data axes. + auto leaf_iter_range = [&]() -> std::pair { + std::vector leaf_var_indices; + for (const auto& axis : compute->axis) { + leaf_var_indices.push_back( + FindLeafVar(self->all_iter_vars.CopyOnWrite(), self->leaf_iter_vars.CopyOnWrite(), axis)); + } + auto minmax_element = std::minmax_element(leaf_var_indices.begin(), leaf_var_indices.end()); + return {*minmax_element.first, *minmax_element.second + 1}; + }(); + CHECK_EQ(leaf_iter_range.first + compute->axis.size(), leaf_iter_range.second) + << "Cannot transform indices if they have already been reordered"; + + // Determine the updated ranges of iteration. + Array initial_ranges; + for (const auto& iter_var : compute->axis) { + initial_ranges.push_back(iter_var->dom); + } + Array final_ranges = map->MapRanges(initial_ranges); + + // Make IterVar objects to represent the new iterations. + auto inverse = map.Inverse(initial_ranges); + Array final_indices_iter; + ICHECK_EQ(inverse->initial_indices.size(), final_ranges.size()); + for (size_t i = 0; i < inverse->initial_indices.size(); i++) { + final_indices_iter.push_back(IterVar(final_ranges[i], inverse->initial_indices[i], kDataPar)); + } + + // Append the new IterVar objects to all_iter_vars + for (const auto& iter_var : final_indices_iter) { + self->all_iter_vars.push_back(iter_var); + } + + // Replace the existing IterVar objects in leaf_iter_vars with the + // new IterVar objects. + self->leaf_iter_vars.erase(self->leaf_iter_vars.begin() + leaf_iter_range.first, + self->leaf_iter_vars.begin() + leaf_iter_range.second); + self->leaf_iter_vars.insert(self->leaf_iter_vars.begin() + leaf_iter_range.first, + final_indices_iter.begin(), final_indices_iter.end()); + + // Define a relationship for each new axis + self->relations.push_back(Transform(compute->axis, final_indices_iter, map, inverse)); + return *this; } From ce8e29a2bd95451c620ed464294c49e3c14e9204 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 22 Oct 2021 13:14:15 -0500 Subject: [PATCH 029/177] [TE] Fill BufferNode::axis_separators from StageNode During ScheduleOps and SchedulePostprocToPrimfunc, the axis separators defined in the stage must be passed through to the TIR BufferNode. --- include/tvm/tir/stmt.h | 8 ++ src/te/schedule/schedule_ops.cc | 16 ++- .../schedule/schedule_postproc_to_primfunc.cc | 135 +++++++++++++++++- 3 files changed, 154 insertions(+), 5 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index bd7528439655..84a465e7cf08 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1289,6 +1289,14 @@ constexpr const char* prefetch_scope = "prefetch_scope"; * PrimFunc attributes for TIR. */ constexpr const char* layout_transforms = "layout_transforms"; +/*! + * \brief Marks the physical axis separators + * + * Only applies to a DataProducer, as it should be made part of the + * Buffer definition in a PrimFunc. See `BufferNode::axis_separators` + * for more details. + */ +constexpr const char* axis_separators = "axis_separators"; /*! * \brief Marks production of double buffer data */ diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 62763d8d51d6..368121c74bc0 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -40,9 +40,11 @@ namespace te { using namespace tir; -// Annotate the statement with the physical layout of the stage. This -// annotation is removed during SchedulePostProcToPrimFunc, where it -// becomes part of the PrimFunc attrs. +// Annotate the statement with the layout transforms and axis +// separators of the stage. These annotations are removed during +// SchedulePostProcToPrimFunc. Afterwards, layout transforms are +// specified in the PrimFunc attrs, and the axis_separators are +// specified in the BufferNode. Stmt WrapLayoutTransformationAttrs(const Stage& stage, Stmt body) { if (stage->layout_transforms.size()) { for (int i = 0; i < stage->op->num_outputs(); i++) { @@ -50,6 +52,14 @@ Stmt WrapLayoutTransformationAttrs(const Stage& stage, Stmt body) { tir::attr::layout_transforms, 1, body); } } + + if (stage->axis_separators.size()) { + for (int i = 0; i < stage->op->num_outputs(); i++) { + body = AttrStmt(Array{stage->op.output(i), stage->axis_separators}, + tir::attr::axis_separators, 1, body); + } + } + return body; } diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 55c91e7f52b9..e8cd0b387f90 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -88,7 +88,8 @@ class TensorToBufferMapper : public StmtExprMutator { Tensor tensor = Downcast(op->node); Buffer buffer = GetOrAllocBuffer(tensor); return AttrStmt(buffer, op->attr_key, op->value, op->body); - } else if (op->attr_key == tir::attr::layout_transforms) { + } else if (op->attr_key == tir::attr::layout_transforms || + op->attr_key == tir::attr::axis_separators) { auto arr = Downcast>(op->node); ICHECK_EQ(arr.size(), 2); @@ -210,7 +211,6 @@ class LayoutTransformAttrUnwrapper : StmtExprMutator { StmtExprVisitor::VisitStmt_(op); } - private: Map> layout_map_; }; @@ -219,6 +219,136 @@ class LayoutTransformAttrUnwrapper : StmtExprMutator { Map> layout_map_; }; +/*! Move axis_separators from an attribute to a buffer property. */ +class AxisSeparatorsAttrUnwrapper : StmtExprMutator { + public: + static tir::PrimFunc Apply(tir::PrimFunc func) { + // Collect the physical layout annotations in the body, which may + // refer to input arguments. + auto axis_separators_map = Collector::Collect(func->body); + + if (axis_separators_map.size()) { + auto write_ptr = func.CopyOnWrite(); + auto pass = AxisSeparatorsAttrUnwrapper(axis_separators_map); + write_ptr->buffer_map = pass.UpdateExternBufferMap(func->buffer_map); + write_ptr->body = pass(func->body); + } + + return func; + } + + explicit AxisSeparatorsAttrUnwrapper(Map> axis_separators_map) + : axis_separators_map_(axis_separators_map) {} + + Map UpdateExternBufferMap(const Map& orig) { + Map output; + for (const auto& kv : orig) { + output.Set(kv.first, GetRemappedBuffer(kv.second)); + } + return output; + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + if (op->attr_key == tir::attr::axis_separators) { + return op->body; + } else { + return ret; + } + } + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + private: + template + Node VisitBufferAccess(Node node) { + Buffer new_buf = GetRemappedBuffer(node->buffer); + if (!node->buffer.same_as(new_buf)) { + auto writer = node.CopyOnWrite(); + writer->buffer = new_buf; + } + return node; + } + + Buffer GetRemappedBuffer(Buffer buf) { + // If this buffer has already been remapped, then return the + // previous value. + auto key = buf.get(); + { + auto it = buffer_remap_.find(key); + if (it != buffer_remap_.end()) { + return it->second; + } + } + + // Otherwise, check if we need to add axis_separators to this + // buffer. + auto lookup = axis_separators_map_.Get(buf); + if (lookup) { + Array axis_separators = lookup.value(); + if (axis_separators.size()) { + auto write_ptr = buf.CopyOnWrite(); + write_ptr->axis_separators = axis_separators; + } + } + + // And cache the result for next time. + buffer_remap_[key] = buf; + + return buf; + } + + /*! Collect the axis separator information of all tensors in the statement. + * + * Must be done before constructing the buffers, since the + * attributes could either apply to the external buffers or to + * internal allocations. + */ + class Collector : StmtExprVisitor { + public: + static Map> Collect(Stmt stmt) { + Collector collector; + collector(std::move(stmt)); + return std::move(collector.axis_separators_map_); + } + + Collector() {} + + void VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tir::attr::axis_separators) { + auto arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2); + + auto buffer = Downcast(arr[0]); + auto axis_separators = Downcast>(arr[1]); + axis_separators_map_.Set(buffer, axis_separators); + } + StmtExprVisitor::VisitStmt_(op); + } + + Map> axis_separators_map_; + }; + + std::unordered_map buffer_remap_; + + Map> axis_separators_map_; +}; + PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, Optional> extern_buffer_opt) { std::unordered_map extern_tensor_map; @@ -260,6 +390,7 @@ PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, PrimFunc func = tir::PrimFunc(params, body, VoidType(), buffer_map); func = LayoutTransformAttrUnwrapper::Apply(std::move(func)); + func = AxisSeparatorsAttrUnwrapper::Apply(std::move(func)); // We mark this PrimFunc as coming from a TE schedule func = WithAttr(func, "from_legacy_te_schedule", Bool(true)); From fb4d71a39f8dcd5e670c2dad0d25f452bbce5e01 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 10 Dec 2021 15:44:39 -0600 Subject: [PATCH 030/177] [TE] Return transformed iteration variables --- include/tvm/te/schedule.h | 6 ++++- python/tvm/te/schedule.py | 46 +++++++++++++++++++++++++++----- src/te/schedule/schedule_lang.cc | 16 +++++++++-- 3 files changed, 58 insertions(+), 10 deletions(-) diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index fc05a8bd2245..deafb3f929ee 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -276,10 +276,14 @@ class Stage : public ObjectRef { * Expressions should be in terms of the variables given in * initial_indices. * + * \param out_iter_vars An optional output location for the updated + * loop iteration variables. + * * \return reference to self */ TVM_DLL Stage& transform_layout(const Array& initial_indices, - const Array& final_indices); + const Array& final_indices, + Array* out_iter_vars = nullptr); /*! \brief Defines separators between groups of axes. * * Used to define `BufferNode::axis_separators`, which has diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index 7b57bd5c98f8..fdd08f9208c9 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -527,9 +527,15 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr """Defines the layout transformation for the current stage's tensor. The map from initial_indices to final_indices must be an - invertible affine transformation. + invertible affine transformation. This method may be called + more than once for a given tensor, in which case each + transformation is applied sequentially. - This method may be called more than once for a given tensor, in which case each + If the stage is a ComputeOp, then the iteration order of the + compute stage is rewritten to be a row-major traversal of the + tensor, and the new loop iteration variables are returned. + For all other stages, the loop iteration order is unmodified, + and the return value is None. Parameters ---------- @@ -543,6 +549,17 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr the current stage's tensor, using the post-transformation layout. + Returns + ------- + new_iter_vars : Optional[List[tvm.tir.IterVar]] + + If the stage is a ComputeOp, then the return will be the + updated loop iteration variables over the data array, in + the same order as the output values from the + `mapping_function`. + + Otherwise, the return value is None. + Examples -------- .. code-block:: python @@ -557,15 +574,29 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr .. code-block:: python - # ``A`` is a tensor whose compute definition is in format, - # and should be transformed such that the last index is - # split, with the slower-chan index of the split placed at the - # slowest changing dimension. + # ``A`` is a tensor whose compute definition is in an + # arbitrary format, and should be transformed such that + # the last index is split, with the slower-changing index + # of the split placed at the slowest changing dimension. s[A].transform_layout( lambda *indices, i: [i//4, *indices, i%4] ) + .. code-block:: python + + # ``B`` is a tensor defined by te.compute to be a copy of + # ``A`, and should be transformed such that ``B``'s layout + # is a transpose of ``A``'s layout. The loop iteration + # that computes ``B`` will correspond to ``B``'s memory + # layout. + + A = te.placeholder([n,m]) + B = te.compute(A.shape, lambda i,j: A[i,j]) + s = te.create_schedule(B.op) + + s[B].transform_layout(lambda i,j: [j,i]) + """ args = [] @@ -626,9 +657,10 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr "Instead received {val} of type {type(val)}." ) - _ffi_api.StageTransformLayout(self, initial_indices, final_indices) + new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices, final_indices) _ffi_api.StageSetAxisSeparators(self, axis_separators) + return new_iter_vars or None @tvm._ffi.register_object diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index f5af49418b7a..0fcd6133c4a2 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -432,7 +432,8 @@ Stage& Stage::rolling_buffer() { return *this; } Stage& Stage::transform_layout(const Array& initial_indices, - const Array& final_indices) { + const Array& final_indices, + Array* out_iter_vars) { StageNode* self = operator->(); IndexMap map(initial_indices, final_indices); self->layout_transforms.push_back(map); @@ -491,6 +492,11 @@ Stage& Stage::transform_layout(const Array& initial_indices, // Define a relationship for each new axis self->relations.push_back(Transform(compute->axis, final_indices_iter, map, inverse)); + // Return the iteration variables as an output. + if (out_iter_vars) { + *out_iter_vars = final_indices_iter; + } + return *this; } @@ -975,7 +981,13 @@ TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffe TVM_REGISTER_GLOBAL("te.StageRollingBuffer").set_body_method(&Stage::rolling_buffer); -TVM_REGISTER_GLOBAL("te.StageTransformLayout").set_body_method(&Stage::transform_layout); +TVM_REGISTER_GLOBAL("te.StageTransformLayout") + .set_body_typed([](Stage stage, const Array& initial_indices, + const Array& final_indices) { + Array new_iter_vars; + stage.transform_layout(initial_indices, final_indices, &new_iter_vars); + return new_iter_vars; + }); TVM_REGISTER_GLOBAL("te.StageSetAxisSeparators").set_body_method(&Stage::set_axis_separators); From e34221d0fbd23077865d7fb140dc2c8663275440 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 11 Jan 2022 12:39:49 -0600 Subject: [PATCH 031/177] Moved Buffer's pre-flatten information to PrimFunc. Since the pre-flatten information is only used for validating user inputs, it makes much more sense to store it alongside the buffer_map. --- include/tvm/tir/buffer.h | 61 +++----------- include/tvm/tir/function.h | 30 +++++++ python/tvm/script/context_maintainer.py | 3 + python/tvm/script/parser.py | 1 + python/tvm/script/tir/special_stmt.py | 8 +- python/tvm/tir/buffer.py | 10 +++ python/tvm/tir/function.py | 37 ++++++++- src/relay/backend/aot_executor_codegen.cc | 2 +- src/relay/backend/build_module.cc | 2 +- .../backend/contrib/cmsisnn/relay_to_tir.cc | 2 +- .../example_target_hooks/relay_to_tir.cc | 2 +- src/tir/analysis/device_constraint_utils.cc | 22 +++++- src/tir/ir/buffer.cc | 12 --- src/tir/ir/function.cc | 9 ++- src/tir/transforms/arg_binder.cc | 79 ++++++++----------- src/tir/transforms/flatten_buffer.cc | 4 + src/tir/transforms/make_packed_api.cc | 8 +- src/tir/transforms/split_host_device.cc | 6 -- src/tir/transforms/storage_flatten.cc | 4 + src/tir/usmp/transform/assign_pool_info.cc | 4 +- .../convert_pool_allocations_to_offsets.cc | 13 +-- tests/python/unittest/test_lower_build.py | 12 +-- 22 files changed, 185 insertions(+), 146 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 2840f98da8e7..1140648fd41b 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -62,45 +62,6 @@ class BufferNode : public Object { * generators. */ Array shape; - /*! \brief The shape of the buffer prior to flattening - * - * This contains the shape as it exists prior to flattening, and is - * used for validating the shape of the tensor passed into the - * packed API. - * - * TODO(Lunderberg): Should this be a reference to the entire - * pre-flattened Buffer instead of just the shape? That would also - * allow the PackedFunc to know how ArgBinder::BindDLTensor (called - * from MakePackedAPI) to know how the tensor should be flattened as - * it is being transferred from the device. - */ - DataType pre_flattened_dtype; - /*! \brief The shape of the buffer prior to flattening - * - * This contains the shape as it exists prior to flattening, and is - * used for validating the shape of the tensor passed into the - * packed API. - * - * TODO(Lunderberg): Should this be a reference to the entire - * pre-flattened Buffer instead of just the shape? That would also - * allow the PackedFunc to know how ArgBinder::BindDLTensor (called - * from MakePackedAPI) to know how the tensor should be flattened as - * it is being transferred from the device. - */ - Optional> pre_flattened_shape; - /*! \brief The strides of the buffer prior to flattening - * - * This contains the strides as they exists prior to flattening, and - * is used for validating an input tensor passed into the packed - * API. - * - * TODO(Lunderberg): Should this be a reference to the entire - * pre-flattened Buffer instead of just the strides? That would - * also allow the PackedFunc to know how ArgBinder::BindDLTensor - * (called from MakePackedAPI) to know how the tensor should be - * flattened as it is being transferred from the device. - */ - Optional> pre_flattened_strides; /*! * \brief Separators between input axes when generating flattened output axes * @@ -141,9 +102,6 @@ class BufferNode : public Object { v->Visit("data", &data); v->Visit("dtype", &dtype); v->Visit("shape", &shape); - v->Visit("pre_flattened_type", &pre_flattened_dtype); - v->Visit("pre_flattened_shape", &pre_flattened_shape); - v->Visit("pre_flattened_strides", &pre_flattened_strides); v->Visit("strides", &strides); v->Visit("axis_separators", &axis_separators); v->Visit("elem_offset", &elem_offset); @@ -155,14 +113,16 @@ class BufferNode : public Object { } bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const { - // Use DefEqual as buffer can define variables - // in its semantics, skip name as name is not important. + // Use DefEqual as buffer can define variables in its semantics, + // skip name as name is not important. + + // The pre-flattened information is only used for type-checking, + // and doesn't represent a different computation. + // + // TODO(Lunderberg): Move the pre-flattened buffer information + // into the PrimFunc's buffer_map. return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && - equal.DefEqual(shape, other->shape) && - equal(pre_flattened_dtype, other->pre_flattened_dtype) && - equal.DefEqual(pre_flattened_shape, other->pre_flattened_shape) && - equal.DefEqual(pre_flattened_strides, other->pre_flattened_strides) && - equal.DefEqual(strides, other->strides) && + equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && equal.DefEqual(axis_separators, other->axis_separators) && equal.DefEqual(elem_offset, other->elem_offset) && equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); @@ -172,9 +132,6 @@ class BufferNode : public Object { hash_reduce.DefHash(data); hash_reduce(dtype); hash_reduce.DefHash(shape); - hash_reduce(pre_flattened_dtype); - hash_reduce.DefHash(pre_flattened_shape); - hash_reduce.DefHash(pre_flattened_strides); hash_reduce.DefHash(strides); hash_reduce.DefHash(elem_offset); hash_reduce.DefHash(axis_separators); diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 1ab911b756df..08691a889e13 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -91,11 +91,23 @@ class PrimFuncNode : public BaseFuncNode { */ Map buffer_map; + /*! \brief The buffer map prior to flattening. + * + * This contains the buffers as they exists prior to flattening, and + * is used for validating an input tensor passed into the packed + * API. Any buffer that is present in `buffer_map` but not present + * in `preflattened_buffer_map` is assumed to be the same before + * and after flattening (e.g. a 1-d tensor that is backed by 1-d + * flat memory). + */ + Map preflattened_buffer_map; + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("params", ¶ms); v->Visit("body", &body); v->Visit("ret_type", &ret_type); v->Visit("buffer_map", &buffer_map); + v->Visit("preflattened_buffer_map", &preflattened_buffer_map); v->Visit("attrs", &attrs); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -104,6 +116,7 @@ class PrimFuncNode : public BaseFuncNode { bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { // visit params and buffer_map first as they contains defs. return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) && + equal(preflattened_buffer_map, other->preflattened_buffer_map) && equal(ret_type, other->ret_type) && equal(body, other->body) && equal(attrs, other->attrs); } @@ -111,6 +124,7 @@ class PrimFuncNode : public BaseFuncNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(params); hash_reduce(buffer_map); + hash_reduce(preflattened_buffer_map); hash_reduce(ret_type); hash_reduce(body); hash_reduce(attrs); @@ -136,15 +150,31 @@ class PrimFunc : public BaseFunc { public: /*! * \brief Constructor + * * \param params The parameters of the function. + * * \param body The body of the function. + * * \param ret_type The return type of the function. + * * \param buffer_map The buffer map for parameter buffer unpacking. + * This contains buffer objects as they appear in the body of the + * PrimFunc. (e.g. a buffer of shape ``[1024]`` originally + * generated as a tensor of shape ``[32, 32]``) + * + * \param preflattened_buffer_map The buffer map for + * parameter buffer unpacking. This contains buffer + * objects as they are expected to be passed in by the + * callee. (e.g. a buffer of shape ``[32, 32]`` originally + * generated as a tensor of shape ``[32, 32]``) + * * \param attrs Additional function attributes. + * * \param span The location of this object in the source code. */ TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), Map buffer_map = Map(), + Map preflattened_buffer_map = Map(), DictAttrs attrs = NullValue(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 149e17bcc701..972e5845fcb9 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -127,6 +127,8 @@ class ContextMaintainer: """List[Var]: The function parameters""" func_buffer_map: Mapping[Var, Buffer] = {} """Mapping[Var, Buffer]: The function buffer map""" + func_preflattened_buffer_map: Mapping[Var, Buffer] = {} + """Mapping[Var, Buffer]: The function buffer map, prior to any flattening.""" func_dict_attr: Mapping[str, Object] = {} """Mapping[str, Object]: The function attrs""" func_var_env_dict: Mapping[Var, str] = {} @@ -151,6 +153,7 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No # function context self.func_params = [] self.func_buffer_map = {} + self.func_preflattened_buffer_map = {} self.func_dict_attr = {} self.func_var_env_dict = {} # parser and analyzer diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index caf6bc4f6778..dca366bf4269 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -484,6 +484,7 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: body, ret_type, buffer_map=self.context.func_buffer_map, + preflattened_buffer_map=self.context.func_preflattened_buffer_map, attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None, span=tvm_span_from_synr(node.span), ) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 20161ad106c1..a513fd087c4b 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -132,6 +132,7 @@ def match_buffer( align=-1, offset_factor=0, buffer_type="default", + flatten_buffer=False, span=None, ): if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: @@ -165,7 +166,12 @@ def match_buffer( self.context.report_error( "Can not bind non-input param to buffer", self.node.rhs.params[0].span ) - self.context.func_buffer_map[param] = buffer + if flatten_buffer: + self.context.func_preflattened_buffer_map[param] = buffer + buffer = buffer.get_flattened_buffer() + self.context.func_buffer_map[param] = buffer + else: + self.context.func_buffer_map[param] = buffer elif isinstance(param, BufferSlice): buffer_region = buffer_slice_to_region(param) self.context.current_block_scope().match_buffers.append( diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 12947bab49a4..d60f6185d0e6 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -143,6 +143,16 @@ def scope(self): """ return _ffi_api.BufferStorageScope(self) # type: ignore + def get_flattened_buffer(self): + """Generate a Buffer that is a flattened version of this buffer. + + Returns + ------- + flattened : Buffer + The corresponding flat buffer. + """ + return _ffi_api.BufferGetFlattenedBuffer(self) # type: ignore + def decl_buffer( shape, diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index bcebab9ddc0a..fdee18f88cf8 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -45,6 +45,9 @@ class PrimFunc(BaseFunc): buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer] The buffer binding map. + preflattened_buffer_map : Optional[Map[tvm.tir.Var, tvm.tir.Buffer]] + The buffer binding map, prior to any flattening. + attrs: Optional[tvm.Attrs] Attributes of the function, can be None @@ -52,9 +55,20 @@ class PrimFunc(BaseFunc): The location of this itervar in the source code. """ - def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, span=None): + def __init__( + self, + params, + body, + ret_type=None, + buffer_map=None, + preflattened_buffer_map=None, + attrs=None, + span=None, + ): + param_list = [] buffer_map = {} if buffer_map is None else buffer_map + preflattened_buffer_map = {} if preflattened_buffer_map is None else preflattened_buffer_map for x in params: x = tvm.runtime.convert(x) if not isinstance(x, Object) else x if isinstance(x, Buffer): @@ -67,8 +81,15 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa raise TypeError("params can only contain Var or Buffer") self.__init_handle_by_constructor__( - _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span # type: ignore - ) + _ffi_api.PrimFunc, + param_list, + body, + ret_type, + buffer_map, + preflattened_buffer_map, + attrs, + span, + ) # type: ignore def with_body(self, new_body, span=None): """Create a new PrimFunc with the same set signatures but a new body. @@ -86,7 +107,15 @@ def with_body(self, new_body, span=None): new_func : PrimFunc The created new function. """ - return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span) + return PrimFunc( + self.params, + new_body, + self.ret_type, + self.buffer_map, + self.preflattened_buffer_map, + self.attrs, + span, + ) def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]): """Specialize parameters of PrimFunc diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index dd80e54553d5..d22d1f1b3adf 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -655,7 +655,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); // Make the PrimFunc - return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, + return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {}, DictAttrs(dict_attrs)); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 2f986669e758..571c1933768b 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -444,7 +444,7 @@ class RelayBuildModule : public runtime::ModuleNode { dict.Set(tvm::attr::kGlobalSymbol, String(::tvm::runtime::symbol::tvm_lookup_linked_param)); DictAttrs attrs{dict}; auto prim = tir::PrimFunc(Array(), tir::SeqStmt(Array()), VoidType(), - Map(), attrs); + Map(), {}, attrs); if (lowered_funcs.find(host_target) == lowered_funcs.end()) { lowered_funcs.Set(host_target, IRModule(Map({}), {}, {}, {}, func_module->attrs)); diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index f366e4ab2635..6342646b9864 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -80,7 +80,7 @@ class RelayToTIRVisitor : public MixedModeMutator { } tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map(), - DictAttrs(dict_attrs)); + Map(), DictAttrs(dict_attrs)); ir_module_->Add(global_var, replacement_func); } diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 1317ceb7a174..86f55caf9342 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -81,7 +81,7 @@ class ConvertAddToSubtract : public MixedModeMutator { }; tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), - buffer_map, DictAttrs(dict_attrs)); + buffer_map, {}, DictAttrs(dict_attrs)); // Switch to TIRToRuntime hook for testing Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 9a1e5ba38cad..1309681513a9 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -210,6 +210,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { // Start with a copy of the current prim_func buffer map. Map new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end()); + Map new_preflattened_buffer_map(prim_func->preflattened_buffer_map.begin(), + prim_func->preflattened_buffer_map.end()); bool any_change = false; // For each constrained parameter... @@ -223,6 +225,23 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { any_change = true; } new_buffer_map.Set(param, new_buffer); + + // Rewrite the pre-flattened buffers to account for constraint. + // This only has an impact if the IRModule being analyzed has + // already been run through the StorageFlatten or FlattenBuffer + // passes. + if (auto opt = prim_func->preflattened_buffer_map.Get(param)) { + Buffer pf_buffer = opt.value(); + if (pf_buffer.same_as(buffer)) { + new_preflattened_buffer_map.Set(param, new_buffer); + } else { + const Buffer new_buffer = RewriteBuffer(pf_buffer, virtual_device); + if (!new_buffer.same_as(pf_buffer)) { + any_change = true; + } + new_preflattened_buffer_map.Set(param, new_buffer); + } + } } // Make sure we have accounted for all prim_func parameters. CheckNoRemainingPointerParams(prim_func, ¤t_primfunc_param_index); @@ -240,7 +259,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { if (any_change) { return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type, - std::move(new_buffer_map), prim_func->attrs, prim_func->span); + std::move(new_buffer_map), std::move(new_preflattened_buffer_map), + prim_func->attrs, prim_func->span); } else { return prim_func; } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 4f0d7de4e202..2ec2f49f0c69 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -378,19 +378,8 @@ Buffer Buffer::GetFlattenedBuffer() const { output_axis_separators.push_back(IntImm(dtype, i + 1)); } - // If a flattening pass is called multiple times, then the - // pre-flattened shape/strides should be from before the first - // application of the pass. - auto pre_flattened_shape = (*this)->pre_flattened_shape.value_or(self->shape); - auto pre_flattened_strides = (*this)->pre_flattened_strides.value_or(self->strides); - auto pre_flattened_dtype = - (*this)->pre_flattened_dtype == DataType::Void() ? self->dtype : (*this)->pre_flattened_dtype; - Buffer output = *this; auto writer = output.CopyOnWrite(); - writer->pre_flattened_dtype = pre_flattened_dtype; - writer->pre_flattened_shape = pre_flattened_shape; - writer->pre_flattened_strides = pre_flattened_strides; writer->shape = output_shape; writer->axis_separators = output_axis_separators; writer->strides = {}; @@ -533,7 +522,6 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array auto n = make_object(); n->data = std::move(data); n->dtype = dtype; - n->pre_flattened_dtype = DataType::Void(); n->shape = std::move(shape); n->strides = std::move(strides); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 1c34e34468b5..058f350059cd 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -37,7 +37,8 @@ LinkedParam::LinkedParam(int64_t id, ::tvm::runtime::NDArray param) { // Get the function type of a PrimFunc PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, - Map buffer_map, DictAttrs attrs, Span span) { + Map buffer_map, Map preflattened_buffer_map, + DictAttrs attrs, Span span) { // Assume void-return type for now // TODO(tvm-team) consider type deduction from body. if (!ret_type.defined()) { @@ -48,6 +49,7 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, n->body = std::move(body); n->ret_type = std::move(ret_type); n->buffer_map = std::move(buffer_map); + n->preflattened_buffer_map = std::move(preflattened_buffer_map); n->attrs = std::move(attrs); n->checked_type_ = n->func_type_annotation(); n->span = std::move(span); @@ -126,8 +128,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_GLOBAL("tir.PrimFunc") .set_body_typed([](Array params, Stmt body, Type ret_type, - Map buffer_map, DictAttrs attrs, Span span) { - return PrimFunc(params, body, ret_type, buffer_map, attrs, span); + Map buffer_map, + Map preflattened_buffer_map, DictAttrs attrs, Span span) { + return PrimFunc(params, body, ret_type, buffer_map, preflattened_buffer_map, attrs, span); }); TVM_REGISTER_GLOBAL("tir.TensorIntrin") diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index c2fa721b8849..19a08e0d30cc 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -154,20 +154,6 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const Stmt nop = Evaluate(0); // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); - ICHECK(buffer->pre_flattened_shape) - << "Cannot bind tensor argument to an unflattened buffer. " - << "Please run StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules) first."; - auto pre_flattened_shape = buffer->pre_flattened_shape.value(); - - ICHECK(buffer->pre_flattened_strides) - << "Cannot bind tensor argument to an unflattened buffer. " - << "Please run StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules) first."; - auto pre_flattened_strides = buffer->pre_flattened_strides.value(); - - ICHECK_NE(buffer->pre_flattened_dtype, DataType::Void()) - << "Cannot bind tensor argument to an unflattened buffer. " - << "Please run StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules) first."; - DataType pre_flattened_dtype = buffer->pre_flattened_dtype; // Helper functions for shape/stride name formatting auto shape_handle_name = [&]() { return arg_name + ".shape"; }; @@ -180,22 +166,22 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, auto shape_element_name = [&](size_t k) { return array_element_name(shape_handle_name(), k); }; auto stride_element_name = [&](size_t k) { return array_element_name(stride_handle_name(), k); }; - PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(pre_flattened_shape.size())); + PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; - ndim_err_msg << arg_name << ".ndim is expected to equal " << pre_flattened_shape.size(); + ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); auto msg = tvm::tir::StringImm(ndim_err_msg.str()); asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); // type checks std::ostringstream type_err_msg; - type_err_msg << arg_name << ".dtype is expected to be " << pre_flattened_dtype; + type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype; PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) == - IntImm(DataType::UInt(8), pre_flattened_dtype.code()) && + IntImm(DataType::UInt(8), buffer->dtype.code()) && TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) == - IntImm(DataType::UInt(8), pre_flattened_dtype.bits()) && + IntImm(DataType::UInt(8), buffer->dtype.bits()) && TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) == - IntImm(DataType::UInt(16), pre_flattened_dtype.lanes())); - if (!(pre_flattened_dtype == DataType::Int(4) || pre_flattened_dtype == DataType::UInt(4) || - pre_flattened_dtype == DataType::Int(1))) { + IntImm(DataType::UInt(16), buffer->dtype.lanes())); + if (!(buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::UInt(4) || + buffer->dtype == DataType::Int(1))) { auto type_msg = tvm::tir::StringImm(type_err_msg.str()); asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); @@ -211,39 +197,38 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, } // shape field - Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), pre_flattened_shape.size())}, - tvm_shape_type, shape_handle_name()); + Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())}, tvm_shape_type, + shape_handle_name()); Var v_shape(shape_handle_name(), DataType::Handle()); def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); init_nest_.emplace_back( LetStmt(buf_shape->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); - for (size_t k = 0; k < pre_flattened_shape.size(); ++k) { - if (pre_flattened_dtype == DataType::Int(4) || pre_flattened_dtype == DataType::UInt(4) || - pre_flattened_dtype == DataType::Int(1)) { + for (size_t k = 0; k < buffer->shape.size(); ++k) { + if (buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::UInt(4) || + buffer->dtype == DataType::Int(1)) { break; } - Bind_( - pre_flattened_shape[k], - cast(pre_flattened_shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})), - shape_element_name(k), true); + Bind_(buffer->shape[k], + cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})), + shape_element_name(k), true); } // strides field - Buffer buf_strides = decl_buffer({IntImm(DataType::Int(32), pre_flattened_strides.size())}, + Buffer buf_strides = decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())}, tvm_shape_type, arg_name + ".strides"); def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back(LetStmt( buf_strides->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data}); - if (pre_flattened_strides.size() == 0) { + if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); PrimExpr expect_stride = make_const(stype, 1); Array conds; - for (size_t i = pre_flattened_shape.size(); i != 0; --i) { + for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; PrimExpr svalue = cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); conds.push_back(expect_stride == svalue); - expect_stride = expect_stride * pre_flattened_shape[k]; + expect_stride = expect_stride * buffer->shape[k]; } std::ostringstream stride_err_msg; stride_err_msg << stride_handle_name() << ": expected to be compact array"; @@ -259,28 +244,28 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, } else if (buffer->buffer_type == kAutoBroadcast) { DataType stype = buffer->DefaultIndexType(); PrimExpr stride = make_const(stype, 1); - for (size_t i = pre_flattened_shape.size(); i != 0; --i) { + for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; - PrimExpr value = cast(pre_flattened_shape[k].dtype(), - BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); + PrimExpr value = + cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); value = tvm::if_then_else(v_strides_is_null, stride, value); - value = tvm::if_then_else(pre_flattened_shape[k] == 1, 0, value); - Bind_(pre_flattened_strides[k], value, stride_element_name(k), true); - stride = analyzer_.Simplify(stride * pre_flattened_shape[k]); + value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); + Bind_(buffer->strides[k], value, stride_element_name(k), true); + stride = analyzer_.Simplify(stride * buffer->shape[k]); } } else { PrimExpr stride_from_shape = 1; - for (int k = pre_flattened_strides.size() - 1; k >= 0; k--) { - PrimExpr explicit_stride = cast(pre_flattened_shape[k].dtype(), - BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); + for (int k = buffer->strides.size() - 1; k >= 0; k--) { + PrimExpr explicit_stride = + cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); - Bind_(pre_flattened_strides[k], + Bind_(buffer->strides[k], tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride), stride_element_name(k), true); - stride_from_shape *= cast(pre_flattened_shape[k].dtype(), - BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})); + stride_from_shape *= + cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})); } } // Byte_offset field. diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 96ad060e5896..0c15f4af2fa2 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -51,10 +51,14 @@ PrimExpr BufferArea(const Buffer& buffer) { class BufferFlattener : public StmtExprMutator { public: static PrimFunc Flatten(PrimFunc func) { + Map preflattened_buffer_map = + Merge(func->buffer_map, func->preflattened_buffer_map); + auto pass = BufferFlattener(func->buffer_map); auto writer = func.CopyOnWrite(); writer->body = pass.VisitStmt(func->body); + writer->preflattened_buffer_map = preflattened_buffer_map; writer->buffer_map = pass.updated_extern_buffer_map_; return func; } diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 8d8020f4d06c..368c84a98e2e 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -218,12 +218,14 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { continue; } - auto it = func_ptr->buffer_map.find(param); - if (it != func_ptr->buffer_map.end()) { - buffer_def.emplace_back(v_arg, (*it).second); + if (func_ptr->preflattened_buffer_map.count(param)) { + buffer_def.emplace_back(v_arg, func_ptr->preflattened_buffer_map[param]); + } else if (func_ptr->buffer_map.count(param)) { + buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]); } else { var_def.emplace_back(v_arg, param); } + if (i < num_packed_args) { // Value loads seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop)); diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 4274733f095e..4f9530b93fda 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -179,12 +179,6 @@ class VarUseDefAnalysis : public StmtExprMutator { visit_arr(buffer->shape); visit_arr(buffer->strides); - if (buffer->pre_flattened_shape) { - visit_arr(buffer->pre_flattened_shape.value()); - } - if (buffer->pre_flattened_strides) { - visit_arr(buffer->pre_flattened_strides.value()); - } } void HandleDef(const VarNode* v) { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 2988129fdca8..859b735f7000 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1188,8 +1188,12 @@ class StorageFlattener : public StmtExprMutator { auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer); + Map preflattened_buffer_map = + Merge(func->buffer_map, func->preflattened_buffer_map); + auto fptr = func.CopyOnWrite(); fptr->body = pass(std::move(fptr->body)); + fptr->preflattened_buffer_map = preflattened_buffer_map; fptr->buffer_map = pass.UpdatedBufferMap(); return func; }; diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc index e75610ea0551..90fe2ce61d49 100644 --- a/src/tir/usmp/transform/assign_pool_info.cc +++ b/src/tir/usmp/transform/assign_pool_info.cc @@ -97,8 +97,8 @@ IRModule PoolInfoAssigner::operator()() { if (kv.second->IsInstance()) { func_ = Downcast(kv.second); Stmt body = this->VisitStmt(func_->body); - PrimFunc new_prim_func = - PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map, func_->attrs); + PrimFunc new_prim_func = PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map, + func_->preflattened_buffer_map, func_->attrs); mod_->Update(gv, new_prim_func); } } diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index cd797681d474..5d267d1a5363 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -164,6 +164,9 @@ Optional PoolAllocationToOffsetConverter::GetResourceHandle(const PrimFunc& PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo( const PrimFunc& original_func) { + ICHECK_EQ(original_func->preflattened_buffer_map.size(), 0) + << "ConvertPoolAllocationsToOffsets pass expects to operate on pre-flattened buffers, prior " + "to StorageFlatten (TE schedules) or FlattenBuffers (TIR schedules)"; ScopeInfo si; Optional resource_handle = GetResourceHandle(original_func); @@ -216,8 +219,8 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( if (emit_tvmscript_printable_) { original_attrs = DictAttrs(); } - PrimFunc ret = - PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs); + PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, {}, + original_attrs); if (!emit_tvmscript_printable_) { ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params); } @@ -340,12 +343,12 @@ IRModule PoolAllocationToOffsetConverter::operator()() { // We dont need attrs of PrimFunc that might include non printable attrs such as target // for unit tests where emit_tvmscript_printable_ is to be used. if (!emit_tvmscript_printable_) { - main_func = - PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs); + main_func = PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, {}, + main_func->attrs); main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params); } else { main_func = - PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs()); + PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, {}, DictAttrs()); } module_->Update(gv, main_func); if (!emit_tvmscript_printable_) { diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index fabf41705698..40d17546470b 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -56,9 +56,9 @@ class LoweredModule: def main(a: T.handle, b: T.handle, c: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) + A = T.match_buffer(a, [128, 128], flatten_buffer=True) + B = T.match_buffer(b, [128, 128], flatten_buffer=True) + C = T.match_buffer(c, [128, 128], flatten_buffer=True) # body for x, y in T.grid(128, 128): C.data[x * 128 + y] = 0.0 @@ -74,9 +74,9 @@ class LoweredTIRModule: def main(a: T.handle, b: T.handle, c: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) + A = T.match_buffer(a, [128, 128], flatten_buffer=True) + B = T.match_buffer(b, [128, 128], flatten_buffer=True) + C = T.match_buffer(c, [128, 128], flatten_buffer=True) # body for x, y in T.grid(128, 128): C.data[x * 128 + y] = 0.0 From 918ea2d24e03299fd11810feabcfe1fc004a39fb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 24 Jan 2022 14:46:00 -0600 Subject: [PATCH 032/177] Updated ethos-u C++ unit tests to remove use of Load/Store. --- .../contrib/ethosu/tir/binary_elementwise.py | 6 +- .../backend/contrib/ethosu/tir/convolution.py | 12 +- .../backend/contrib/ethosu/tir/depthwise.py | 12 +- .../relay/backend/contrib/ethosu/tir/dma.py | 52 +++++---- .../backend/contrib/ethosu/tir/identity.py | 18 +-- .../backend/contrib/ethosu/tir/passes.py | 109 +++++++++++------- .../backend/contrib/ethosu/tir/pooling.py | 4 +- .../relay/backend/contrib/ethosu/tir/spec.py | 15 ++- .../backend/contrib/ethosu/tir/transform.py | 11 +- .../contrib/ethosu/tir/unary_elementwise.py | 6 +- .../relay/backend/contrib/ethosu/tir/utils.py | 28 ++--- .../contrib/ethosu/tir_to_cs_translator.py | 21 ++-- 12 files changed, 166 insertions(+), 128 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py index 53b46aeafbf5..11e472070a6b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py @@ -77,12 +77,12 @@ def get_binary_elementwise_params( _, _, _, _, _, inner = get_outer_loops(body, "NHWC") op = ignore_cast(inner.value) - input_pointer = ignore_cast(op.a).buffer_var - input_pointer1 = ignore_cast(op.b).buffer_var + input_pointer = ignore_cast(op.a).buffer.data + input_pointer1 = ignore_cast(op.b).buffer.data if reversed_operands: input_pointer, input_pointer1 = input_pointer1, input_pointer - output_pointer = inner.buffer_var + output_pointer = inner.buffer.data # Get feature map info serial_ifm, _ = get_ifm_params(input_pointer, producers) serial_ifm2, _ = get_ifm_params(input_pointer1, producers) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 50c27cc01689..bdca6a874ca5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -59,8 +59,8 @@ def get_conv2d_params(stmt, producers, consumers): loads = get_loads(rc.body) # stores = [output] stores = get_stores(rc.body) - input_pointer = loads[1].buffer_var - output_pointer = stores[0].buffer_var + input_pointer = loads[1].buffer.data + output_pointer = stores[0].buffer.data # Get feature map info serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) @@ -75,16 +75,16 @@ def get_conv2d_params(stmt, producers, consumers): ) # Get scale_bias info scale_bias_load = loads[3] - scale_bias_base = get_base_address(scale_bias_load.index) + scale_bias_base = [get_base_address(index) for index in scale_bias_load.indices] serial_scale_bias = SerialAddressRange( - address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), + address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base), length=SCALE_BIAS_LENGTH * serial_ofm[3], ) # Get weight info weight_load = loads[2] - weight_base = get_base_address(weight_load.index) + weight_base = [get_base_address(index) for index in weight_load.indices] serial_weight = SerialAddressRange( - address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), + address=tvm.tir.BufferLoad(weight_load.buffer, weight_base), length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1] * rc.extent, ) # Get activation info diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py index b1a4ebd82a88..b39ec36e4231 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py @@ -68,8 +68,8 @@ def get_depthwise_conv2d_params( loads = get_loads(rw.body) # stores = [output] stores = get_stores(rw.body) - input_pointer = loads[1].buffer_var - output_pointer = stores[0].buffer_var + input_pointer = loads[1].buffer.data + output_pointer = stores[0].buffer.data # Get feature map info serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) @@ -84,16 +84,16 @@ def get_depthwise_conv2d_params( ) # Get scale_bias info scale_bias_load = loads[3] - scale_bias_base = get_base_address(scale_bias_load.index) + scale_bias_base = [get_base_address(index) for index in scale_bias_load.indices] serial_scale_bias = SerialAddressRange( - address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), + address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base), length=SCALE_BIAS_LENGTH * serial_ofm[3], ) # Get weight info weight_load = loads[2] - weight_base = get_base_address(weight_load.index) + weight_base = [get_base_address(index) for index in weight_load.indices] serial_weight = SerialAddressRange( - address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), + address=tvm.tir.BufferLoad(weight_load.buffer, weight_base), length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1], ) # Get activation info diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py index 9f82d7478265..34ea9ef87c96 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -41,12 +41,12 @@ def get_pad_params(stmt): """ _, body = get_op_attrs(stmt) n, h, w, c, _, inner = get_outer_loops(body, "NHWC") - output_pointer = inner.buffer_var + output_pointer = inner.buffer.data pad = SerialPadding(top=0, left=0, bottom=0, right=0) if isinstance(inner.value, tvm.tir.Call): - input_pointer = inner.value.args[1].buffer_var + input_pointer = inner.value.args[1].buffer.data else: - input_pointer = inner.value.buffer_var + input_pointer = inner.value.buffer.data return pad, input_pointer, output_pointer padded_shape = [n.extent, h.extent, w.extent, c.extent] @@ -126,11 +126,11 @@ def get_convert_to_nhwc_params(stmt): # compute that is deemed uneccesary isn't removed by TVM. if attrs["layout"] == "NHCWB16": inner = inner.body - input_pointer = inner.value.b.buffer_var + input_pointer = inner.value.b.buffer.data else: - input_pointer = inner.value.buffer_var + input_pointer = inner.value.buffer.data - output_pointer = inner.buffer_var + output_pointer = inner.buffer.data return c.extent, input_pointer, output_pointer @@ -154,13 +154,13 @@ def get_convert_to_nhcwb16_params(stmt): """ attrs, body = get_op_attrs(stmt) _, _, _, c, b, inner = get_outer_loops(body, attrs["layout"]) - output_pointer = inner.buffer_var + output_pointer = inner.buffer.data if isinstance(inner.value, tvm.tir.Call): cond = inner.value.args[0] out_channels = cond.b.value - input_pointer = inner.value.args[1].buffer_var + input_pointer = inner.value.args[1].buffer.data else: - input_pointer = inner.value.buffer_var + input_pointer = inner.value.buffer.data out_channels = c.extent * b.extent if attrs["layout"] == "NHCWB16" else c.extent return out_channels, input_pointer, output_pointer @@ -186,12 +186,17 @@ def get_read_params(stmt): """ attrs, body = get_op_attrs(stmt) _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) - input_pointer = inner.value.buffer_var - output_pointer = inner.buffer_var + input_pointer = inner.value.buffer.data + output_pointer = inner.buffer.data + + # Needed for stride calculation, can replace with + # inner.value.buffer.strides in future. + assert len(inner.value.indices) == 1, "Ethos-U DMA expects flattened buffers" stride_vars = [h.loop_var, w.loop_var, c.loop_var] - strides = get_strides(inner.value.index, stride_vars) - base_address = get_base_address(inner.value.index) - data_type = inner.buffer_var.type_annotation.element_type.dtype + strides = get_strides(inner.value.indices[0], stride_vars) + + base_address = [get_base_address(index) for index in inner.value.indices] + data_type = inner.buffer.data.type_annotation.element_type.dtype return ( SerialFeatureMap( data_type=data_type, @@ -201,7 +206,7 @@ def get_read_params(stmt): tile_height_0=h.extent, tile_height_1=0, tile_width_0=w.extent, - tile_address_0=tvm.tir.Load(data_type, inner.value.buffer_var, base_address), + tile_address_0=tvm.tir.BufferLoad(inner.value.buffer, base_address), tile_address_1=0, tile_address_2=0, tile_address_3=0, @@ -237,12 +242,17 @@ def get_write_params(stmt): """ attrs, body = get_op_attrs(stmt) _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) - input_pointer = inner.value.buffer_var - output_pointer = inner.buffer_var + input_pointer = inner.value.buffer.data + output_pointer = inner.buffer.data + + # Needed for stride calculation, can replace with + # inner.value.buffer.strides in future. + assert len(inner.indices) == 1, "Ethos-U DMA expects flattened buffers" stride_vars = [h.loop_var, w.loop_var, c.loop_var] - strides = get_strides(inner.index, stride_vars) - base_address = get_base_address(inner.index) - data_type = inner.buffer_var.type_annotation.element_type.dtype + strides = get_strides(inner.indices[0], stride_vars) + + base_address = [get_base_address(index) for index in inner.indices] + data_type = inner.buffer.data.type_annotation.element_type.dtype return ( SerialFeatureMap( data_type=data_type, @@ -252,7 +262,7 @@ def get_write_params(stmt): tile_height_0=h.extent, tile_height_1=0, tile_width_0=w.extent, - tile_address_0=tvm.tir.Load(data_type, inner.buffer_var, base_address), + tile_address_0=tvm.tir.BufferLoad(inner.buffer, base_address), tile_address_1=0, tile_address_2=0, tile_address_3=0, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py index 6dccb5a15c97..aacff55c451b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py @@ -59,12 +59,14 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur fm_inner = inner.value if fm_type == "ifm" else inner + # Needed for stride calculation, can replace with + # inner.value.buffer.strides in future. + assert len(fm_inner.indices) == 1, "Ethos-U passes expect flattened buffers" stride_vars = [l.loop_var for l in loops] - strides = get_strides(fm_inner.index, stride_vars) + strides = get_strides(fm_inner.indices[0], stride_vars) - base_address = get_base_address(fm_inner.index) - data_type = inner.buffer_var.type_annotation.element_type.dtype - pointer = fm_inner.buffer_var + base_address = [get_base_address(index) for index in fm_inner.index] + data_type = inner.buffer.data.type_annotation.element_type.dtype serial_feature_map = SerialFeatureMap( data_type=data_type, @@ -74,7 +76,7 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur tile_height_0=loops[0].extent, tile_height_1=0, tile_width_0=loops[1].extent if len(loops) > 1 else 1, - tile_address_0=tvm.tir.Load(data_type, pointer, base_address), + tile_address_0=tvm.tir.BufferLoad(fm_inner, base_address), tile_address_1=0, tile_address_2=0, tile_address_3=0, @@ -86,7 +88,7 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur stride_c=strides[2] if len(strides) > 2 else 1, ) - output_pointer = inner.buffer_var + output_pointer = inner.buffer.data return serial_feature_map, output_pointer @@ -130,8 +132,8 @@ def get_identity_params( # loads = [input, LUT, LUT] loads = get_loads(stmt) - input_pointer = loads[0].buffer_var - output_pointer = stmt.buffer_var + input_pointer = loads[0].buffer.data + output_pointer = stmt.buffer.data read = producers[input_pointer] write = consumers[output_pointer] diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index c2fff8abb9b0..62a2e01f37e8 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -28,7 +28,7 @@ from .identity import get_identity_params from .unary_elementwise import get_unary_elementwise_params from .transform import get_copy_params -from .utils import get_weights_pointer, get_scale_bias_pointer +from .utils import get_weights_buffer, get_scale_bias_buffer def RemoveZeroStores(): @@ -82,8 +82,8 @@ def _resolve_pointers(stmt): loads = [] def _get_loads(stmt): - if isinstance(stmt, tvm.tir.Load): - loads.append(stmt.buffer_var) + if isinstance(stmt, tvm.tir.BufferLoad): + loads.append(stmt.buffer.data) if isinstance(stmt, tvm.tir.Allocate): pointer_to_extents[stmt.buffer_var] = stmt.extents @@ -94,8 +94,8 @@ def _get_loads(stmt): elif isinstance(stmt, tvm.tir.AttrStmt): if stmt.attr_key == "pragma_op": tvm.tir.stmt_functor.post_order_visit(stmt, _get_loads) - for load_buffer in loads: - pointer_to_consumer[load_buffer] = stmt + for load_pointer in loads: + pointer_to_consumer[load_pointer] = stmt def _replace_operator(stmt): """Replace operators with call_externs, having derived the parameters @@ -232,11 +232,14 @@ def DivideConstants(const_dict): def _visit(stmt): new_args = [] for i, arg in enumerate(stmt.args): - if isinstance(arg, tvm.tir.expr.Load): + if isinstance(arg, tvm.tir.expr.BufferLoad): # If we're trying to load a buffer that maps to a constant - if arg.buffer_var in buffer_to_const: - const = buffer_to_const[arg.buffer_var] - offset = int(arg.index) + if arg.buffer.data in buffer_to_const: + const = buffer_to_const[arg.buffer.data] + + assert len(arg.indices) == 1, "Ethos-U passes expects flattened buffers" + + offset = int(arg.indices[0]) # Note by convention the arg after a constant read is the length of the read length = int(stmt.args[i + 1]) # If it's anything other than a full read, create a new buffer @@ -244,9 +247,9 @@ def _visit(stmt): new_consts.append(const[offset : offset + length]) new_buffer = tvm.tir.decl_buffer((length,), arg.dtype) new_buffers.append(new_buffer) - new_args.append(tvm.tir.expr.Load(new_buffer.dtype, new_buffer.data, 0)) + new_args.append(tvm.tir.expr.BufferLoad(new_buffer.data, [0])) continue - keep_buffers.add(arg.buffer_var) + keep_buffers.add(arg.buffer.data) new_args.append(arg) @@ -278,7 +281,15 @@ def _ftransform(f, mod, ctx): new_buffer_map[handle] = new_buffer new_const_dict[len(new_params) - 1] = new_consts[i] - new_f = tvm.tir.PrimFunc(new_params, new_body, f.ret_type, new_buffer_map, f.attrs, f.span) + new_f = tvm.tir.PrimFunc( + new_params, + new_body, + f.ret_type, + new_buffer_map, + f.preflattened_buffer_map, + f.attrs, + f.span, + ) return new_f def _divide_constants(mod): @@ -343,30 +354,31 @@ def _visit_encode_pre(stmt): # Handle copies as a special-case by propagating the buffer information # from the read to the write pointer. if stmt.args[0] == "ethosu_copy": - read_pointer = stmt.args[1].buffer_var + read_pointer = stmt.args[1].buffer.data if read_pointer in pointer_to_buffer: - write_pointer = stmt.args[3].buffer_var + write_pointer = stmt.args[3].buffer.data # Assert writing to the base of the write_var (pre-StorageRewrite) - assert stmt.args[3].index == 0 - assert stmt.args[1].index == 0 + assert list(stmt.args[3].indices) == [0] + assert list(stmt.args[1].indices) == [0] pointer_to_buffer[write_pointer] = pointer_to_buffer[read_pointer] + rewrite_buffer[stmt.args[3].buffer] = stmt.args[1].buffer else: # Encode the weights - weights_pointer = get_weights_pointer(stmt) - if weights_pointer is not None: - assert weights_pointer in pointer_to_buffer - weights_buffer = pointer_to_buffer[weights_pointer] - weights_value = buffer_to_const[weights_buffer] + old_weights_buffer = get_weights_buffer(stmt) + if old_weights_buffer is not None: + assert old_weights_buffer.data in pointer_to_buffer + new_weights_buffer = pointer_to_buffer[old_weights_buffer.data] + weights_value = buffer_to_const[new_weights_buffer] new_weights_value = _encode_weights(stmt, weights_value) - _new_buffer(weights_buffer, new_weights_value) + _new_buffer(new_weights_buffer, new_weights_value) # Align the scale_bias to 16 bytes - scale_bias_pointer = get_scale_bias_pointer(stmt) - if scale_bias_pointer is not None: - assert scale_bias_pointer in pointer_to_buffer - scale_bias_buffer = pointer_to_buffer[scale_bias_pointer] - scale_bias_value = buffer_to_const[scale_bias_buffer] + old_scale_bias_buffer = get_scale_bias_buffer(stmt) + if old_scale_bias_buffer is not None: + assert old_scale_bias_buffer.data in pointer_to_buffer + new_scale_bias_buffer = pointer_to_buffer[old_scale_bias_buffer.data] + scale_bias_value = buffer_to_const[new_scale_bias_buffer] new_scale_bias_value = _align_scale_bias(stmt, scale_bias_value) - _new_buffer(scale_bias_buffer, new_scale_bias_value) + _new_buffer(new_scale_bias_buffer, new_scale_bias_value) def _visit_encode_post(stmt): # Because encoding may change the data type (e.g. bias to uint8) and type information @@ -398,14 +410,14 @@ def _visit_rewrite(stmt): new_buffers = rewrite_buffer.values() for i in range(1, len(stmt.args)): # If the previous argument was a load, the current should be a length - if isinstance(stmt.args[i - 1], tvm.tir.Load): + if isinstance(stmt.args[i - 1], tvm.tir.BufferLoad): load = stmt.args[i - 1] - pointer = load.buffer_var - if pointer in pointer_to_buffer: - buffer = pointer_to_buffer[pointer] + old_buffer = load.buffer + if old_buffer.data in pointer_to_buffer: + new_buffer = pointer_to_buffer[old_buffer.data] # Only rewrite the arguments of buffers that have been encoded - if buffer in new_buffers: - new_arg = np.prod(list(pointer_to_buffer[pointer].shape)) + if new_buffer in new_buffers: + new_arg = np.prod(list(new_buffer.shape)) new_args.append(new_arg) continue new_args.append(stmt.args[i]) @@ -429,14 +441,12 @@ def _visit_rewrite(stmt): # The following rewrites would be better expressed by just rewriting the Vars, however # ir_transform doesn't seem to visit Vars. So instead we do the next best thing and rewrite # the nodes which contain the Vars. - if isinstance(stmt, tvm.tir.Load): - load_pointer = stmt.buffer_var - if load_pointer in rewrite_pointer: - new_pointer = rewrite_pointer[load_pointer] - element_type = new_pointer.type_annotation.element_type.dtype - return tvm.tir.Load( - element_type, new_pointer, stmt.index, stmt.predicate, stmt.span - ) + if isinstance(stmt, tvm.tir.BufferLoad): + if stmt.buffer.data in pointer_to_buffer: + load_buffer = pointer_to_buffer[stmt.buffer.data] + if load_buffer in rewrite_buffer: + new_buffer = rewrite_buffer[load_buffer] + return tvm.tir.BufferLoad(new_buffer, stmt.indices, stmt.span) if isinstance(stmt, tvm.tir.AttrStmt): node_pointer = stmt.node if node_pointer in rewrite_pointer: @@ -457,7 +467,10 @@ def _ftransform(f, mod, ctx): ) # Then perform the rewrites new_body = tvm.tir.stmt_functor.ir_transform( - f.body, None, _visit_rewrite, ["tir.Call", "tir.Allocate", "tir.Load", "tir.AttrStmt"] + f.body, + None, + _visit_rewrite, + ["tir.Call", "tir.Allocate", "tir.BufferLoad", "tir.AttrStmt"], ) new_buffer_map = {} # Rewrite the buffer map and const dict to instead use the encoded versions @@ -474,7 +487,15 @@ def _ftransform(f, mod, ctx): else: new_buffer_map[param] = buffer - new_f = tvm.tir.PrimFunc(f.params, new_body, f.ret_type, new_buffer_map, f.attrs, f.span) + new_f = tvm.tir.PrimFunc( + f.params, + new_body, + f.ret_type, + new_buffer_map, + f.preflattened_buffer_map, + f.attrs, + f.span, + ) return new_f def _encode_constants(mod): diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py index e929caa2409b..3b32ef01a938 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py @@ -61,8 +61,8 @@ def get_pooling_params( loads = get_loads(rw.body) # stores = [output] stores = get_stores(rw.body) - input_pointer = loads[1].buffer_var - output_pointer = stores[0].buffer_var + input_pointer = loads[1].buffer.data + output_pointer = stores[0].buffer.data # Get feature map info serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py index f9d38df9d901..d390fc0e10dc 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py @@ -93,10 +93,10 @@ def __init__( tile_height_0: int, tile_height_1: int, tile_width_0: int, - tile_address_0: tvm.tir.expr.Load, - tile_address_1: Union[tvm.tir.expr.Load, int], - tile_address_2: Union[tvm.tir.expr.Load, int], - tile_address_3: Union[tvm.tir.expr.Load, int], + tile_address_0: tvm.tir.expr.BufferLoad, + tile_address_1: Union[tvm.tir.expr.BufferLoad, int], + tile_address_2: Union[tvm.tir.expr.BufferLoad, int], + tile_address_3: Union[tvm.tir.expr.BufferLoad, int], scale: float, zero_point: int, layout: str, @@ -148,7 +148,7 @@ class SerialAddressRange(SerializableFormat): """Specialization class to retrieve arguments of a AddressRange (similiar to NpuAddressRange of Vela) on a predefined ordering""" - def __init__(self, address: tvm.tir.expr.Load, length: int): + def __init__(self, address: tvm.tir.expr.BufferLoad, length: int): self.address = address self.length = length @@ -237,7 +237,10 @@ class SerialCopy(SerializableFormat): a ethosu.copy tir extern call on a predefined ordering""" def __init__( - self, read_address: tvm.tir.expr.Load, length: int, write_address: tvm.tir.expr.Load + self, + read_address: tvm.tir.expr.BufferLoad, + length: int, + write_address: tvm.tir.expr.BufferLoad, ): self.read_address = read_address self.length = length diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py index 141505a3dfba..53e0bd2a728b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py @@ -50,17 +50,16 @@ def get_copy_params(stmt, producers, consumers): _, body = get_op_attrs(stmt) length = body.extent write_store = body.body - write_base = get_base_address(write_store.index) + write_base = [get_base_address(index) for index in write_store.indices] read_load = body.body.value - read_base = get_base_address(read_load.index) - dtype = body.body.value.dtype + read_base = [get_base_address(index) for index in read_load.indices] return ( SerialCopy( - read_address=tvm.tir.expr.Load(dtype, read_load.buffer_var, read_base), + read_address=tvm.tir.expr.BufferLoad(read_load.buffer, read_base), length=length, - write_address=tvm.tir.expr.Load(dtype, write_store.buffer_var, write_base), + write_address=tvm.tir.expr.BufferLoad(write_store.buffer, write_base), ), - write_store.buffer_var, + write_store.buffer.data, None, True, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py index b550b79e7906..9c570d88c163 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py @@ -54,11 +54,11 @@ def get_unary_elementwise_params(stmt, producers, consumers): input_pointer = None if isinstance(inner.value, tir.expr.Select): # ABS - input_pointer = inner.value.condition.b.buffer_var + input_pointer = inner.value.condition.b.buffer.data if isinstance(inner.value, tir.expr.Sub): # CLZ - input_pointer = inner.value.b.args[0].buffer_var - output_pointer = inner.buffer_var + input_pointer = inner.value.b.args[0].buffer.data + output_pointer = inner.buffer.data # Get feature map info serial_ifm, _ = get_ifm_params(input_pointer, producers) serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index de1c0ab19f6e..506f18ba3a99 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -21,20 +21,20 @@ # TODO(@mbaret): Formalise this with a specification -def get_weights_pointer(tir_extern_call): +def get_weights_buffer(tir_extern_call): """Get the weights pointer from a NPU extern call if it exists""" supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] if tir_extern_call.args[0] in supported_ops: - return tir_extern_call.args[41].buffer_var + return tir_extern_call.args[41].buffer return None # TODO(@mbaret): Formalise this with a specification -def get_scale_bias_pointer(tir_extern_call): +def get_scale_bias_buffer(tir_extern_call): """Get the scale_bias pointer from a NPU extern call if it exists""" supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] if tir_extern_call.args[0] in supported_ops: - return tir_extern_call.args[44].buffer_var + return tir_extern_call.args[44].buffer return None @@ -177,23 +177,23 @@ def get_outer_loops(stmt, layout): def get_loads(stmt): - """Get the Load statements. + """Get the BufferLoad statements. Parameters ---------- stmt : tvm.tir.Stmt - The statement to get the Loads from. + The statement to get the BufferLoads from. Returns ------- - loads : list of tvm.tir.Load - The Loads found. + loads : list of tvm.tir.BufferLoad + The BufferLoads found. """ loads = [] def _visit(s): - if isinstance(s, tvm.tir.Load): + if isinstance(s, tvm.tir.BufferLoad): loads.append(s) tvm.tir.stmt_functor.post_order_visit(stmt, _visit) @@ -201,23 +201,23 @@ def _visit(s): def get_stores(stmt): - """Get the Store statements. + """Get the BufferStore statements. Parameters ---------- stmt : tvm.tir.Stmt - The statement to get the Stores from. + The statement to get the BufferStores from. Returns ------- - stores : list of tvm.tir.Store - The Stores found. + stores : list of tvm.tir.BufferStore + The BufferStores found. """ stores = [] def _visit(s): - if isinstance(s, tvm.tir.Store): + if isinstance(s, tvm.tir.BufferStore): stores.append(s) tvm.tir.stmt_functor.post_order_visit(stmt, _visit) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index c55a6310ffa5..63c8e8255726 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -271,7 +271,7 @@ def assign_addresses(buffer_info, npu_ops): This is the dictionary obtained via calling extract_buffer_info. The key is the buffer name to BufferInfo npu_ops : list - A list of Vela NpuOps with tir.Loads for addresses + A list of Vela NpuOps with tir.BufferLoads for addresses Returns ------- npu_ops : list @@ -283,16 +283,19 @@ def assign_addresses(buffer_info, npu_ops): """ def replace_npu_fm_with_address(npu_fm): - assert isinstance(npu_fm.tiles.addresses[0], tvm.tir.Load) + assert isinstance(npu_fm.tiles.addresses[0], tvm.tir.BufferLoad) # We currently does not support tiles # Change this when tiles are needed # (i.e. when using rolling buffers) assert npu_fm.tiles.addresses[1:] == [0, 0, 0] npu_fm.tiles.addresses[1:] = [0, 0, 0] - buffer = npu_fm.tiles.addresses[0].buffer_var + buffer = npu_fm.tiles.addresses[0].buffer.data assert buffer in buffer_addresses.keys() address, buffer_type = buffer_addresses[buffer] - index = npu_fm.tiles.addresses[0].index * ( + assert ( + len(npu_fm.tiles.addresses[0].indices) == 1 + ), "Ethos-U translation expects flattened buffers" + index = npu_fm.tiles.addresses[0].indices[0] * ( np.iinfo(np.dtype(npu_fm.tiles.addresses[0])).bits // 8 ) npu_fm.tiles.addresses[0] = address + int(index) @@ -300,8 +303,8 @@ def replace_npu_fm_with_address(npu_fm): return npu_fm def replace_npu_address_range_with_address(npu_addr_range): - assert isinstance(npu_addr_range.address, tvm.tir.Load) - buffer = npu_addr_range.address.buffer_var + assert isinstance(npu_addr_range.address, tvm.tir.BufferLoad) + buffer = npu_addr_range.address.buffer.data assert buffer in buffer_addresses.keys(), f"searching for buffer : {buffer}, but not found" address, buffer_type = buffer_addresses[buffer] return vapi.NpuAddressRange(_REGION_MAP[buffer_type], address, npu_addr_range.length) @@ -316,11 +319,11 @@ def replace_tir_loads(npu_object): def classify_io(buffer): for _npu_op in npu_ops: if issubclass(type(_npu_op), vapi.NpuBlockOperation): - if _npu_op.ifm and _npu_op.ifm.tiles.addresses[0].buffer_var == buffer: + if _npu_op.ifm and _npu_op.ifm.tiles.addresses[0].buffer.data == buffer: return BufferType.input - if _npu_op.ifm2 and _npu_op.ifm2.tiles.addresses[0].buffer_var == buffer: + if _npu_op.ifm2 and _npu_op.ifm2.tiles.addresses[0].buffer.data == buffer: return BufferType.input - if _npu_op.ofm and _npu_op.ofm.tiles.addresses[0].buffer_var == buffer: + if _npu_op.ofm and _npu_op.ofm.tiles.addresses[0].buffer.data == buffer: return BufferType.output raise ValueError(f"Unused IO : {buffer} in tir module.") From 0ffe060ff066e2238527bdf61e406735300ac81b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 27 Jan 2022 09:36:51 -0600 Subject: [PATCH 033/177] Bugfix, layout transformation. Error occured during conversion from TE to IRModule, when layout transforms were applied to a reader of a `cache_read`. --- src/te/schedule/schedule_ops.cc | 17 ++++++ .../python/unittest/test_transform_layout.py | 59 +++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 368121c74bc0..75736d0333da 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -233,6 +233,23 @@ class SchedulePostProc : public StmtExprMutator { return this->VisitStmt(op->body); } } + } else if (op->attr_key == tir::attr::layout_transforms || + op->attr_key == tir::attr::axis_separators) { + auto arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2); + + Stmt body = op->body; + + Tensor tensor = Downcast(arr[0]); + auto it = replace_op_.find(tensor->op.get()); + if (it != replace_op_.end()) { + if (it->second.defined()) { + return AttrStmt(Array{it->second.output(tensor->value_index), arr[1]}, + op->attr_key, op->value, this->VisitStmt(op->body)); + } else { + return this->VisitStmt(op->body); + } + } } return StmtExprMutator::VisitStmt_(op); } diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py index fb85463f1cb8..c70be6f782eb 100755 --- a/tests/python/unittest/test_transform_layout.py +++ b/tests/python/unittest/test_transform_layout.py @@ -416,5 +416,64 @@ def test_use_transformed_axes( self.compare_tir_loop_order(func.body, expected_loop_order) +class TestTransformCache: + A_size = tvm.testing.parameter(16) + + transform_A = tvm.testing.parameter(by_dict={"transformA": True, "": False}) + transform_B = tvm.testing.parameter(by_dict={"transformB": True, "": False}) + cache_A = tvm.testing.parameter(by_dict={"cacheA": True, "": False}) + cache_B = tvm.testing.parameter(by_dict={"cacheB": True, "": False}) + + @tvm.testing.fixture + def schedule_args(self, A_size, transform_A, transform_B, cache_A, cache_B, dtype): + A = te.placeholder(shape=[A_size], dtype=dtype, name="A") + B = te.compute(A.shape, lambda i: A[i], name="B") + s = te.create_schedule(B.op) + + if transform_A: + A_axis = s[A].transform_layout(lambda i: [i // 4, i % 4]) + + if transform_B: + B_axis = s[B].transform_layout(lambda i: [i // 4, i % 4]) + + if cache_A: + AA = s.cache_read(A, "shared", [B]) + + if cache_B: + BB = s.cache_write(B, "shared") + + return [s, [A, B]] + + @tvm.testing.fixture + def ref_data(self, A_size, dtype, transform_A, transform_B): + a_np = (100 * np.random.uniform(size=A_size)).astype(dtype) + b_np = a_np + + if transform_A: + a_np = a_np.reshape((-1, 4)) + + if transform_B: + b_np = b_np.reshape((-1, 4)) + + return a_np, b_np + + def test_lower(self, schedule_args): + tvm.lower(*schedule_args) + + def test_execute(self, target, dev, schedule_args, ref_data, dtype): + func = tvm.build(*schedule_args, target=target) + + a_np, b_np = ref_data + a = tvm.nd.array(a_np, dev) + b = tvm.nd.empty(b_np.shape, dtype=dtype, device=dev) + + func(a, b) + + if "int" in dtype: + np.testing.assert_equal(b.numpy(), b_np) + else: + tvm.testing.assert_allclose(b.numpy(), b_np) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) From 6e0ca38f1dc9b4ac0b4141a8e5917c8d3c8c2beb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 12 Jan 2022 12:22:21 -0600 Subject: [PATCH 034/177] In test directory, replacing all instances of T.load. --- .../test_ethosu/test_encode_constants.py | 114 ++--- .../test_ethosu/test_replace_conv2d.py | 112 ++--- .../contrib/test_ethosu/test_replace_copy.py | 30 +- .../contrib/test_ethosu/test_scheduler.py | 22 +- .../test_ethosu/test_tir_to_cs_translator.py | 166 +++---- .../contrib/test_ethosu/test_vela_api.py | 16 +- tests/python/unittest/test_lower_build.py | 12 +- .../unittest/test_target_codegen_llvm.py | 2 +- .../test_tir_analysis_calculate_workspace.py | 30 +- ...t_tir_analysis_detect_buffer_access_lca.py | 2 +- tests/python/unittest/test_tir_intrin.py | 5 +- .../unittest/test_tir_lower_match_buffer.py | 2 +- .../test_tir_schedule_cache_read_write.py | 6 +- .../test_tir_schedule_compute_inline.py | 4 +- ...est_tir_transform_compact_buffer_region.py | 4 +- ..._tir_transform_convert_for_loops_serial.py | 6 +- .../test_tir_transform_flatten_buffer.py | 26 +- .../test_tir_transform_loop_partition.py | 4 +- tests/python/unittest/test_tir_usmp_algo.py | 38 +- ...st_tir_usmp_analysis_extract_bufferinfo.py | 184 ++++---- ...orm_convert_pool_allocations_to_offsets.py | 110 ++--- tests/python/unittest/test_tir_usmp_utils.py | 12 +- .../unittest/test_tvmscript_error_report.py | 4 +- .../unittest/test_tvmscript_roundtrip.py | 414 +++++++----------- 24 files changed, 612 insertions(+), 713 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 315712996ac8..7bf0c9a181aa 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -37,29 +37,29 @@ class WeightStreamOnly: def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") - buffer_4 = T.buffer_var("uint8", "") - buffer_5 = T.buffer_var("uint8", "") - buffer_6 = T.buffer_var("uint8", "") - buffer_7 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") + buffer_4 = T.buffer_decl([], "uint8") + buffer_5 = T.buffer_decl([], "uint8") + buffer_6 = T.buffer_decl([], "uint8") + buffer_7 = T.buffer_decl([], "uint8") # body placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 128, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 128, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 112, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 112, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 112, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -110,17 +110,17 @@ class RereadWeights: def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") # body placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 304, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 80, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 304, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 80, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -171,14 +171,14 @@ class DirectReadOnly: def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 160, 12, T.load("uint8", buffer_3, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, 12, buffer_1[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 160, 12, buffer_3[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -228,33 +228,33 @@ class MixedRead: def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") - buffer_4 = T.buffer_var("uint8", "") - buffer_5 = T.buffer_var("uint8", "") - buffer_6 = T.buffer_var("uint8", "") - buffer_7 = T.buffer_var("uint8", "") - buffer_8 = T.buffer_var("uint8", "") - buffer_9 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") + buffer_4 = T.buffer_decl([], "uint8") + buffer_5 = T.buffer_decl([], "uint8") + buffer_6 = T.buffer_decl([], "uint8") + buffer_7 = T.buffer_decl([], "uint8") + buffer_8 = T.buffer_decl([], "uint8") + buffer_9 = T.buffer_decl([], "uint8") # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_8, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_9, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, 12, buffer_1[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_8[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_9[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 67fb2c760962..cc11e21e94e0 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -336,16 +336,16 @@ class Conv2dDoubleCascade1: def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") # body ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 160, 12, T.load("uint8", buffer_2, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 304, 12, T.load("uint8", buffer_1, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 160, 12, T.load("uint8", buffer_2, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 32), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 304, 12, T.load("uint8", buffer_1, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, 12, buffer_2[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, 12, buffer_1[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[12], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, 12, buffer_2[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[32], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, 12, buffer_1[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -355,16 +355,16 @@ class Conv2dDoubleCascade2: def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") # body ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 1312, 12, T.load("uint8", buffer_1, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 2608, 12, T.load("uint8", buffer, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 1312, 12, T.load("uint8", buffer_1, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 2608, 12, T.load("uint8", buffer, 0), 80, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, 12, buffer_1[0], 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, 12, buffer[0], 80, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[48], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, 12, buffer_1[0], 320, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, 12, buffer[0], 80, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -374,18 +374,18 @@ class Conv2dDoubleCascade3: def main(placeholder_5: T.Buffer[(1, 16, 16, 3), "int8"], ethosu_write_1: T.Buffer[(1, 20, 4, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") # body ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, T.load("int8", placeholder_5.data, 576), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, T.load("int8", ethosu_write_1.data, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, 12, buffer_2[0], 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, 12, buffer_1[0], 80, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, placeholder_5[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, 12, buffer_2[0], 320, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, 12, buffer_1[0], 80, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, placeholder_5[576], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, 12, buffer_2[0], 320, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, ethosu_write_1[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, 12, buffer_1[0], 80, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -395,16 +395,16 @@ class Conv2dDoubleCascade4: def main(placeholder_5: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write_1: T.Buffer[(1, 8, 2, 8, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") # body ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 12, T.load("uint8", buffer_1, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 11040, 12, T.load("uint8", buffer_2, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 12, T.load("uint8", buffer_1, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 1024), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 11040, 12, T.load("uint8", buffer_2, 0), 272, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, 12, buffer_1[0], 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, 12, buffer_2[0], 272, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[256], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, 12, buffer_1[0], 352, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[1024], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, 12, buffer_2[0], 272, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -594,10 +594,10 @@ class Conv2dInlineCopy1: def main(placeholder_3: T.Buffer[(1, 10, 12, 8), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, T.load("int8", placeholder_3.data, 120), 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 848, 12, T.load("uint8", buffer_1, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, 12, buffer_1[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -607,10 +607,10 @@ class Conv2dInlineCopy2: def main(placeholder_3: T.Buffer[(1, 7, 9, 5), "int8"], ethosu_write_1: T.Buffer[(1, 3, 5, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, T.load("int8", placeholder_3.data, 146), 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 656, 12, T.load("uint8", buffer, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, 12, buffer[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -649,11 +649,11 @@ class Conv2dInlineReshape1: def main(placeholder_3: T.Buffer[(4, 6, 8, 1), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -663,11 +663,11 @@ class Conv2dInlineReshape2: def main(placeholder_3: T.Buffer[(1, 24, 8), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -677,11 +677,11 @@ class Conv2dInlineReshape3: def main(placeholder_3: T.Buffer[(192, 1), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -691,11 +691,11 @@ class Conv2dInlineReshape4: def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 7aee57d548fe..afe7aa81c73f 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -34,14 +34,14 @@ class ReferenceModule: def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") # body placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 304, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 80, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -78,19 +78,19 @@ class WeightStream: def main(placeholder_5: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buffer[(1, 16, 16, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") # body placeholder_global = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_d_global = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 272, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 64, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, T.load("int8", ethosu_write_1.data, 10), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 12, T.load("uint8", placeholder_d_global, 0), 64, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 416, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 112, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 416, 12, placeholder_d_global[0], 112, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 272, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 64, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 272, 12, placeholder_d_global[0], 64, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 6a4aba4e38fc..ab2e3942582e 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -182,23 +182,23 @@ class DiamondGraphTir: @T.prim_func def main(input_buffer: T.Buffer[(1, 56, 56, 96), "int8"], output_buffer: T.Buffer[(1, 56, 56, 24), "int8"]) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - weight_buffer = T.buffer_var("uint8", "") - bias_buffer = T.buffer_var("uint8", "") - weight_buffer2 = T.buffer_var("uint8", "") - bias_buffer2 = T.buffer_var("uint8", "") + weight_buffer = T.buffer_decl([], "uint8") + bias_buffer = T.buffer_decl([], "uint8") + weight_buffer2 = T.buffer_decl([], "uint8") + bias_buffer2 = T.buffer_decl([], "uint8") placeholder_global = T.allocate([2608], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True}) featuremap_buffer = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin": True}) featuremap_buffer2 = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", weight_buffer, 0), 2608, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", bias_buffer, 0), 240, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, T.load("int8", input_buffer.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 2608, 12, T.load("uint8", placeholder_d_global, 0), 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", weight_buffer2, 0), 736, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", bias_buffer2, 0), 240, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 736, 12, T.load("uint8", placeholder_d_global, 0), 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer2, 0), 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", output_buffer.data, 0), 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", weight_buffer[0], 2608, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", bias_buffer[0], 240, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, input_buffer.data[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 2608, 12, placeholder_d_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", weight_buffer2[0], 736, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", bias_buffer2[0], 240, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 736, 12, placeholder_d_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, output_buffer.data[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index add8021083c6..fafc4d84ea5b 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -36,10 +36,10 @@ class SingleEthosUConv2D: def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_4 = T.buffer_var("uint8", "") - placeholder_5 = T.buffer_var("uint8", "") + placeholder_4 = T.buffer_decl([], "uint8") + placeholder_5 = T.buffer_decl([], "uint8") # body - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_4, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_4[0], 0, 12, placeholder_5[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) # fmt: on @@ -51,17 +51,17 @@ class MultiEthosUConv2D: def main(placeholder_6: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_conv2d_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_9 = T.buffer_var("uint8", "") - placeholder_7 = T.buffer_var("uint8", "") - placeholder_8 = T.buffer_var("uint8", "") - placeholder_5 = T.buffer_var("uint8", "") + placeholder_9 = T.buffer_decl([], "uint8") + placeholder_7 = T.buffer_decl([], "uint8") + placeholder_8 = T.buffer_decl([], "uint8") + placeholder_5 = T.buffer_decl([], "uint8") # body ethosu_conv2d_2 = T.allocate([1024], "uint8", "global") ethosu_conv2d_3 = T.allocate([2048], "uint8", "global") - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7, 0), 0, 12, T.load("uint8", placeholder_8, 0), 0, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7, 0), 0, 12, T.load("uint8", placeholder_8, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, placeholder_6[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, placeholder_7[0], 0, 12, placeholder_8[0], 0, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_9[0], 0, 12, placeholder_5[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, placeholder_6[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, placeholder_7[0], 0, 12, placeholder_8[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, ethosu_conv2d_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_9[0], 0, 12, placeholder_5[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) # fmt: on @@ -73,14 +73,14 @@ class MultiEthosUCopy: def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_5 = T.buffer_var("uint8", "") - placeholder_4 = T.buffer_var("uint8", "") + placeholder_5 = T.buffer_decl([], "uint8") + placeholder_4 = T.buffer_decl([], "uint8") # body placeholder_global = T.allocate([256], "uint8", "global") placeholder_d_global = T.allocate([8], "int32", "global") - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", placeholder_4, 0), 256, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("int32", placeholder_5, 0), 8, T.load("int32", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 0, 12, T.load("uint8", placeholder_d_global, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", placeholder_4[0], 256, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", placeholder_5[0], 8, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 0, 12, placeholder_d_global[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) # fmt: on @@ -90,14 +90,14 @@ def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Bu class WeightStreamOnly: @T.prim_func def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") - buffer_4 = T.buffer_var("uint8", "") - buffer_5 = T.buffer_var("uint8", "") - buffer_6 = T.buffer_var("uint8", "") - buffer_7 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") + buffer_4 = T.buffer_decl([], "uint8") + buffer_5 = T.buffer_decl([], "uint8") + buffer_6 = T.buffer_decl([], "uint8") + buffer_7 = T.buffer_decl([], "uint8") # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True, @@ -112,18 +112,18 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ # body placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 128, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 128, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 112, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 112, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 112, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -134,16 +134,16 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ class MixedRead: @T.prim_func def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") - buffer_4 = T.buffer_var("uint8", "") - buffer_5 = T.buffer_var("uint8", "") - buffer_6 = T.buffer_var("uint8", "") - buffer_7 = T.buffer_var("uint8", "") - buffer_8 = T.buffer_var("uint8", "") - buffer_9 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") + buffer_4 = T.buffer_decl([], "uint8") + buffer_5 = T.buffer_decl([], "uint8") + buffer_6 = T.buffer_decl([], "uint8") + buffer_7 = T.buffer_decl([], "uint8") + buffer_8 = T.buffer_decl([], "uint8") + buffer_9 = T.buffer_decl([], "uint8") # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True, @@ -161,19 +161,19 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_8, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_9, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, 12, buffer_1[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_8[0], 80, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_9[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -528,7 +528,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle placeholder_3 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_depthwise_conv2d_1 = T.match_buffer(ethosu_depthwise_conv2d, [1, 6, 7, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 8, 3, 8, 0, 8, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.6), 11, "NHWC", 24, 3, 1, "int8", 6, 7, 3, 6, 0, 7, T.load("int8", ethosu_depthwise_conv2d_1.data, 0), 0, 0, 0, T.float32(0.26), 15, "NHWC", 21, 3, 1, 2, 3, 1, 1, 1, 1, T.load("int8", placeholder_4.data, 0), 18, 13, T.load("uint8", placeholder_5.data, 0), 30, 0, 0, 0, 0, "CLIP", 15, 105, "TFL", "NONE", dtype="int8")) + T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder_3[0], 0, 0, 0, T.float32(0.6), 11, "NHWC", 24, 3, 1, "int8", 6, 7, 3, 6, 0, 7, ethosu_depthwise_conv2d_1[0], 0, 0, 0, T.float32(0.26), 15, "NHWC", 21, 3, 1, 2, 3, 1, 1, 1, 1, placeholder_4[0], 18, 13, placeholder_5[0], 30, 0, 0, 0, 0, "CLIP", 15, 105, "TFL", "NONE", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -666,9 +666,9 @@ def populate_ethosu_copy_calls(stmt): class MixedConstantDatatypes: @T.prim_func def main(placeholder_4: T.Buffer[(1, 8, 16, 16), "int8"], ethosu_write_1: T.Buffer[(1, 1, 1, 16), "int8"]) -> None: - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("int16", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "int16") # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True, @@ -680,11 +680,11 @@ def main(placeholder_4: T.Buffer[(1, 8, 16, 16), "int8"], ethosu_write_1: T.Buff placeholder_d_global = T.allocate([160], "uint8", "global") ethosu_write_2 = T.allocate([16], "int16", "global") placeholder_d_global_1 = T.allocate([1], "int16", "global") - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 272, T.load("uint8", placeholder_global, 0), dtype="uint8")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 160, T.load("uint8", placeholder_d_global, 0), dtype="uint8")) - T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 16, 16, 8, 0, 16, T.load("int8", placeholder_4.data, 0), 0, 0, 0, T.float32(0.0039215548895299435), -128, "NHWC", 256, 16, 1, "int16", 1, 1, 16, 1, 0, 1, T.load("int16", ethosu_write_2, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, 16, 8, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 0, T.load("uint8", placeholder_d_global, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="int16")) - T.evaluate(T.call_extern("ethosu_copy", T.load("int16", buffer_2, 0), 1, T.load("int16", placeholder_d_global_1, 0), dtype="int16")) - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int16", 1, 1, 16, 1, 0, 1, T.load("int16", ethosu_write_2, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "int16", 1, 1, 1, 1, 0, 1, T.load("int16", placeholder_d_global_1, 0), 0, 0, 0, T.float32(0.0078125018482064768), 0, "NHWC", 1, 1, 1, "int8", 1, 1, 16, 1, 0, 1, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "MUL", 0, "NONE", 0, 0, "NATURAL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 272, placeholder_global[0], dtype="uint8")) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 160, placeholder_d_global[0], dtype="uint8")) + T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 16, 16, 8, 0, 16, placeholder_4[0], 0, 0, 0, T.float32(0.0039215548895299435), -128, "NHWC", 256, 16, 1, "int16", 1, 1, 16, 1, 0, 1, ethosu_write_2[0], 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, 16, 8, 1, 1, 1, 1, placeholder_global[0], 272, 0, placeholder_d_global[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="int16")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 1, placeholder_d_global_1[0], dtype="int16")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int16", 1, 1, 16, 1, 0, 1, ethosu_write_2[0], 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "int16", 1, 1, 1, 1, 0, 1, placeholder_d_global_1[0], 0, 0, 0, T.float32(0.0078125018482064768), 0, "NHWC", 1, 1, 1, "int8", 1, 1, 16, 1, 0, 1, ethosu_write_1[0], 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "MUL", 0, "NONE", 0, 0, "NATURAL", dtype="int8")) # fmt: on @@ -964,7 +964,7 @@ def main(placeholder: T.handle, placeholder_3: T.handle, ethosu_write: T.handle) placeholder_4 = T.match_buffer(placeholder, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 5, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_pooling", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_4.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 5, 3, 5, 0, 5, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 15, 3, 1, "AVG", 2, 3, 2, 1, 1, 1, 1, 1, 1, 0, "CLIP", 10, 100, "TFL", "NONE", dtype="int8")) + T.evaluate(T.call_extern("ethosu_pooling", "int8", 5, 9, 3, 5, 0, 9, placeholder_4[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 5, 3, 5, 0, 5, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 15, 3, 1, "AVG", 2, 3, 2, 1, 1, 1, 1, 1, 1, 0, "CLIP", 10, 100, "TFL", "NONE", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1041,7 +1041,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1 ) # body - T.evaluate(T.call_extern( "ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "ADD", 0, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern( "ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "ADD", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1057,7 +1057,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SUB", 0, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SUB", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1072,7 +1072,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MUL", 0, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MUL", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1088,7 +1088,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MIN", 0, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MIN", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1104,7 +1104,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MAX", 0, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MAX", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1120,7 +1120,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHR", 0, "NONE", 0, 0, "TFL", dtype="int32")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHR", 0, "NONE", 0, 0, "TFL", dtype="int32")) __tvm_meta__ = None # fmt: on @@ -1136,7 +1136,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHL", 0, "CLIP", 10, 100, "TFL", dtype="int32")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHL", 0, "CLIP", 10, 100, "TFL", dtype="int32")) __tvm_meta__ = None # fmt: on @@ -1257,7 +1257,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "ADD", 1, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "ADD", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1272,7 +1272,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SUB", 1, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SUB", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1287,7 +1287,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MUL", 1, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MUL", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1303,7 +1303,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MIN", 1, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MIN", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1319,7 +1319,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MAX", 1, "CLIP", 10, 100, "TFL", dtype="int8")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MAX", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None # fmt: on @@ -1335,7 +1335,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHR", 1, "NONE", 0, 0, "TFL", dtype="int32")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHR", 1, "NONE", 0, 0, "TFL", dtype="int32")) __tvm_meta__ = None # fmt: on @@ -1351,7 +1351,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1) ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHL", 1, "CLIP", 10, 100, "TFL", dtype="int32")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHL", 1, "CLIP", 10, 100, "TFL", dtype="int32")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index af75dc82a0bb..be75ff9b827a 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -72,7 +72,7 @@ def main( 8, 0, 8, - T.load("uint8", placeholder_3.data, 0), + placeholder_3.data[0], 0, 0, 0, @@ -89,7 +89,7 @@ def main( 8, 0, 8, - T.load("uint8", ethosu_conv2d_1.data, 0), + ethosu_conv2d_1.data[0], 0, 0, 0, @@ -105,10 +105,10 @@ def main( 1, 1, 1, - T.load("uint8", placeholder_4.data, 0), + placeholder_4.data[0], 0, 12, - T.load("uint8", placeholder_5.data, 0), + placeholder_5.data[0], 0, 0, 0, @@ -168,7 +168,7 @@ def main( 8, 0, 8, - T.load("uint8", placeholder_3.data, 0), + placeholder_3.data[0], 0, 0, 0, @@ -185,7 +185,7 @@ def main( 8, 0, 8, - T.load("uint8", ethosu_conv2d_1.data, 0), + ethosu_conv2d_1.data[0], 0, 0, 0, @@ -201,10 +201,10 @@ def main( 1, 1, 1, - T.load("uint8", placeholder_4.data, 0), + placeholder_4.data[0], 0, 12, - T.load("uint8", placeholder_5.data, 0), + placeholder_5.data[0], 0, 0, 0, diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index 40d17546470b..326554e90e5a 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -61,11 +61,9 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], flatten_buffer=True) # body for x, y in T.grid(128, 128): - C.data[x * 128 + y] = 0.0 + C[x * 128 + y] = 0.0 for k in T.serial(0, 128): - C.data[x * 128 + y] = T.load("float32", C.data, x * 128 + y) + T.load( - "float32", A.data, x * 128 + k - ) * T.load("float32", B.data, y * 128 + k) + C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k] @tvm.script.ir_module @@ -79,11 +77,9 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], flatten_buffer=True) # body for x, y in T.grid(128, 128): - C.data[x * 128 + y] = 0.0 + C[x * 128 + y] = 0.0 for k in T.serial(0, 128): - C.data[x * 128 + y] = T.load("float32", C.data, x * 128 + y) + T.load( - "float32", A.data, x * 128 + k - ) * T.load("float32", B.data, y * 128 + k) + C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k] def test_lower_build_te_schedule(): diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 0e303aaff6eb..c09161d67ce6 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -925,7 +925,7 @@ def threadpool_nested_parallel_loop( T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in T.parallel(4): for j in T.parallel(4): - T.store(B.data, i * 4 + j, T.load("float32", A.data, i * 4 + j) * 2.0) + T.store(B.data, i * 4 + j, A.data[i * 4 + j] * 2.0) with pytest.raises(tvm.TVMError) as e: tvm.build({"llvm": tvm.IRModule.from_expr(threadpool_nested_parallel_loop)}) diff --git a/tests/python/unittest/test_tir_analysis_calculate_workspace.py b/tests/python/unittest/test_tir_analysis_calculate_workspace.py index 4b61625014e2..89e0791e457d 100644 --- a/tests/python/unittest/test_tir_analysis_calculate_workspace.py +++ b/tests/python/unittest/test_tir_analysis_calculate_workspace.py @@ -34,21 +34,21 @@ def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.hand PaddedInput_22 = T.allocate([131072], "int16", "global") DepthwiseConv2d_9 = T.allocate([100352], "int32", "global") for i1_29, i2_39, i3_40 in T.grid(16, 16, 512): - PaddedInput_22[(((i1_29*8192) + (i2_39*512)) + i3_40)] = T.if_then_else(((((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15)), T.load("int16", placeholder_147.data, ((((i1_29*7168) + (i2_39*512)) + i3_40) - 7680)), T.int16(0), dtype="int16") + PaddedInput_22[(((i1_29*8192) + (i2_39*512)) + i3_40)] = T.if_then_else(((((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15)), placeholder_147[((((i1_29*7168) + (i2_39*512)) + i3_40) - 7680)], T.int16(0), dtype="int16") for i_9, j_9, c_9 in T.grid(14, 14, 512): DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] = 0 for di_9, dj_9 in T.grid(3, 3): - DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] = (T.load("int32", DepthwiseConv2d_9, (((i_9*7168) + (j_9*512)) + c_9)) + (T.load("int16", PaddedInput_22, (((((i_9*8192) + (di_9*8192)) + (j_9*512)) + (dj_9*512)) + c_9)).astype("int32")*T.load("int16", placeholder_148.data, (((di_9*1536) + (dj_9*512)) + c_9)).astype("int32"))) + DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] = (DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] + (PaddedInput_22[(((((i_9*8192) + (di_9*8192)) + (j_9*512)) + (dj_9*512)) + c_9)].astype("int32")*placeholder_148[(((di_9*1536) + (dj_9*512)) + c_9)].astype("int32"))) for ax1_27, ax2_28, ax3_30 in T.grid(14, 14, 512): - DepthwiseConv2d_9[(((ax1_27*7168) + (ax2_28*512)) + ax3_30)] = (T.load("int32", DepthwiseConv2d_9, (((ax1_27*7168) + (ax2_28*512)) + ax3_30)) + T.load("int32", placeholder_149.data, ax3_30)) + DepthwiseConv2d_9[(((ax1_27*7168) + (ax2_28*512)) + ax3_30)] = (DepthwiseConv2d_9[(((ax1_27*7168) + (ax2_28*512)) + ax3_30)] + placeholder_149[ax3_30]) for i1_30, i2_40, i3_41 in T.grid(14, 14, 512): - DepthwiseConv2d_9[(((i1_30*7168) + (i2_40*512)) + i3_41)] = T.q_multiply_shift(T.load("int32", DepthwiseConv2d_9, (((i1_30*7168) + (i2_40*512)) + i3_41)), 1269068532, 31, -4, dtype="int32") + DepthwiseConv2d_9[(((i1_30*7168) + (i2_40*512)) + i3_41)] = T.q_multiply_shift(DepthwiseConv2d_9[(((i1_30*7168) + (i2_40*512)) + i3_41)], 1269068532, 31, -4, dtype="int32") for i1_31, i2_41, i3_42 in T.grid(14, 14, 512): - DepthwiseConv2d_9[(((i1_31*7168) + (i2_41*512)) + i3_42)] = T.max(T.max(T.load("int32", DepthwiseConv2d_9, (((i1_31*7168) + (i2_41*512)) + i3_42)), 255), 0) + DepthwiseConv2d_9[(((i1_31*7168) + (i2_41*512)) + i3_42)] = T.max(T.max(DepthwiseConv2d_9[(((i1_31*7168) + (i2_41*512)) + i3_42)], 255), 0) for ax1_28, ax2_29, ax3_31 in T.grid(14, 14, 512): - PaddedInput_22[(((ax1_28*7168) + (ax2_29*512)) + ax3_31)] = T.load("int32", DepthwiseConv2d_9, (((ax1_28*7168) + (ax2_29*512)) + ax3_31)).astype("uint8") + PaddedInput_22[(((ax1_28*7168) + (ax2_29*512)) + ax3_31)] = DepthwiseConv2d_9[(((ax1_28*7168) + (ax2_29*512)) + ax3_31)].astype("uint8") for ax1_29, ax2_30, ax3_32 in T.grid(14, 14, 512): - T_cast_49.data[(((ax1_29*7168) + (ax2_30*512)) + ax3_32)] = T.load("uint8", PaddedInput_22, (((ax1_29*7168) + (ax2_30*512)) + ax3_32)).astype("int16") + T_cast_49[(((ax1_29*7168) + (ax2_30*512)) + ax3_32)] = PaddedInput_22[(((ax1_29*7168) + (ax2_30*512)) + ax3_32)].astype("int16") # fmt: on @@ -64,29 +64,29 @@ def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handl # body PaddedInput_25 = T.allocate([1, 16, 16, 512], "int16", "global") for i1_35, i2_46, i3_47 in T.grid(16, 16, 512): - PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = T.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), T.load("int16", placeholder_165.data, ((((i1_35*7168) + (i2_46*512)) + i3_47) - 7680)), T.int16(0), dtype="int16") + PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = T.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), placeholder_165[((((i1_35*7168) + (i2_46*512)) + i3_47) - 7680)], T.int16(0), dtype="int16") T_add_11 = T.allocate([1, 14, 14, 512], "int32", "global") with T.allocate([1, 14, 14, 512], "int32", "global") as DepthwiseConv2d_11: for i_11, j_11, c_11 in T.grid(14, 14, 512): DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = 0 for di_11, dj_11 in T.grid(3, 3): - DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (T.load("int32", DepthwiseConv2d_11, (((i_11*7168) + (j_11*512)) + c_11)) + (T.load("int16", PaddedInput_25, (((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)).astype("int32")*T.load("int16", placeholder_166.data, (((di_11*1536) + (dj_11*512)) + c_11)).astype("int32"))) + DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] + (PaddedInput_25[(((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)].astype("int32")*placeholder_166[(((di_11*1536) + (dj_11*512)) + c_11)].astype("int32"))) for ax1_44, ax2_45, ax3_47 in T.grid(14, 14, 512): - T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = (T.load("int32", DepthwiseConv2d_11, (((ax1_44*7168) + (ax2_45*512)) + ax3_47)) + T.load("int32", placeholder_167.data, ax3_47)) + T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = (DepthwiseConv2d_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] + placeholder_167[ax3_47]) compute_22 = T.allocate([1, 14, 14, 512], "int32", "global") with T.allocate([1, 14, 14, 512], "int32", "global") as T_cast_78: for ax1_45, ax2_46, ax3_48 in T.grid(14, 14, 512): - T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = T.load("int32", T_add_11, (((ax1_45*7168) + (ax2_46*512)) + ax3_48)) + T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = T_add_11[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] for i1_36, i2_47, i3_48 in T.grid(14, 14, 512): - compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = T.q_multiply_shift(T.load("int32", T_cast_78, (((i1_36*7168) + (i2_47*512)) + i3_48)), 1948805937, 31, -5, dtype="int32") + compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = T.q_multiply_shift(T_cast_78[(((i1_36*7168) + (i2_47*512)) + i3_48)], 1948805937, 31, -5, dtype="int32") T_cast_79 = T.allocate([1, 14, 14, 512], "uint8", "global") with T.allocate([1, 14, 14, 512], "int32", "global") as compute_23: for i1_37, i2_48, i3_49 in T.grid(14, 14, 512): - compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = T.max(T.max(T.load("int32", compute_22, (((i1_37*7168) + (i2_48*512)) + i3_49)), 255), 0) + compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = T.max(T.max(compute_22[(((i1_37*7168) + (i2_48*512)) + i3_49)], 255), 0) for ax1_46, ax2_47, ax3_49 in T.grid(14, 14, 512): - T_cast_79[(((ax1_46*7168) + (ax2_47*512)) + ax3_49)] = T.load("int32", compute_23, (((ax1_46*7168) + (ax2_47*512)) + ax3_49)).astype("uint8") + T_cast_79[(((ax1_46*7168) + (ax2_47*512)) + ax3_49)] = compute_23[(((ax1_46*7168) + (ax2_47*512)) + ax3_49)].astype("uint8") for ax1_47, ax2_48, ax3_50 in T.grid(14, 14, 512): - T_cast_77.data[(((ax1_47*7168) + (ax2_48*512)) + ax3_50)] = T.load("uint8", T_cast_79, (((ax1_47*7168) + (ax2_48*512)) + ax3_50)).astype("int16") + T_cast_77[(((ax1_47*7168) + (ax2_48*512)) + ax3_50)] = T_cast_79[(((ax1_47*7168) + (ax2_48*512)) + ax3_50)].astype("int16") # fmt: on diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 1a0dfd09a2df..6645983f5211 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -57,7 +57,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: T.store(A, i * 16 + j, 1) for i in range(0, 16): for j in range(0, 16): - T.evaluate(T.load("float32", A, i * 16 + j)) + T.evaluate(A[i * 16 + j]) for j in range(0, 16): T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, T.float32(0), dtype="handle")) diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 3e9e7fd33fd9..444cc9121e77 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -237,9 +237,8 @@ def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None: # body for i in T.serial(0, n): d_1.data[(i * stride_3)] = ( - T.load("float32", A_1.data, (i * stride)) - * T.load("float32", B_1.data, (i * stride_1)) - ) + T.load("float32", C_1.data, (i * stride_2)) + A_1.data[(i * stride)] * B_1.data[(i * stride_1)] + ) + C_1.data[(i * stride_2)] def test_fma(): diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 5ca9cf0da3c9..0c93f3a50382 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -465,7 +465,7 @@ def fail_match_load(a: T.handle) -> None: T.reads(A[i, j]) T.writes([]) sub_A = T.match_buffer(A[i, j], ()) - T.evaluate(T.load("float32", sub_A.data, 0)) + T.evaluate(sub_A[0]) @T.prim_func diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index 22f26ce0318a..00bcb710e24b 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -80,7 +80,7 @@ def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + D.data[vi * 128 + vj] = A.data[vi * 128 + vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) @@ -272,7 +272,7 @@ def cache_read_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) vi, vj = T.axis.remap("SS", [i, j]) T.reads(A_global[vi, vj]) T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = T.load("float16", A_global.data, vi * 128 + vj) + D.data[vi * 128 + vj] = A_global.data[vi * 128 + vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) @@ -481,7 +481,7 @@ def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(D_global[vi, vj]) - D_global.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + D_global.data[vi * 128 + vj] = A.data[vi * 128 + vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index 5cc36c0df878..a098b2322792 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -183,7 +183,7 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) - C[vi, vj] = T.load("float32", B.data, vi * 128 + vj) + 1.0 + C[vi, vj] = B.data[vi * 128 + vj] + 1.0 @T.prim_func @@ -201,7 +201,7 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None: T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) T.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) - C[vi, vj] = T.load("float32", B.data, vi * 16 + vj) + 1.0 + C[vi, vj] = B.data[vi * 16 + vj] + 1.0 @T.prim_func diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 9b844853f243..145b9af9eddd 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -482,7 +482,7 @@ def opaque_access_annotated_func(a: T.handle) -> None: # they are not compatible with actual buffer accesses. T.reads([B[i]]) T.writes([C[i : i + 9]]) - T.store(C.data, i, T.load("float32", B.data, i)) + T.store(C.data, i, B.data[i]) @T.prim_func @@ -502,7 +502,7 @@ def compacted_opaque_access_annotated_func(a: T.handle) -> None: # they are not compatible with actual buffer accesses. T.reads([B[i]]) T.writes([C[i : i + 9]]) - T.store(C.data, i, T.load("float32", B.data, i)) + T.store(C.data, i, B.data[i]) def test_elementwise(): diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py index a91fa2591e00..98b894fbf733 100644 --- a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -34,14 +34,14 @@ def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T. PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global") for i0_i1_fused_3 in T.parallel(0, 28): for i2_3, i3_3 in T.grid(28, 192): - T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), placeholder_33.data[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)], True) for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784): for ax3_2 in T.serial(0, 16): Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global") T.store(Conv2dOutput_3, 0, 0, True) for rc_3 in T.serial(0, 192): - T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) - T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T.store(Conv2dOutput_3, 0, (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34.data[((rc_3*16) + ax3_2)], "int32"))), True) + T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35.data[ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) # fmt: on diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index ca3d4aa70d0b..be8c3c1f656f 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -55,9 +55,9 @@ def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: for i in T.serial(0, 16): B_new = T.allocate([16], "float32", "global") for j in T.serial(0, 16): - B_new[j] = T.load("float32", A.data, ((i * 16) + j)) + 1.0 + B_new[j] = A.data[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): - C.data[((i * 16) + j)] = T.load("float32", B_new, j) * 2.0 + C.data[((i * 16) + j)] = B_new[j] * 2.0 @T.prim_func @@ -97,9 +97,9 @@ def flattened_gpu_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i2, 2) B = T.allocate([16], "float32", "local") for j in range(0, 16): - B[j] = T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + 1.0 + B[j] = A.data[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): - C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = T.load("float32", B, j) * 2.0 + C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 @T.prim_func @@ -132,9 +132,9 @@ def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> for i in range(0, n): B = T.allocate([m], "float32", "global") for j in range(0, m): - B[j] = T.load("float32", A.data, i * m + j) + 1.0 + B[j] = A.data[i * m + j] + 1.0 for j in range(0, m): - C.data[i * m + j] = T.load("float32", B, j) * 2.0 + C.data[i * m + j] = B[j] * 2.0 @T.prim_func @@ -157,7 +157,7 @@ def flattened_predicate_func(a: T.handle, c: T.handle) -> None: for i, j in T.grid(5, 7): if i * 7 + j < 32: - C.data[i * 7 + j] = T.load("float32", A.data, i * 7 + j) + 1.0 + C.data[i * 7 + j] = A.data[i * 7 + j] + 1.0 @T.prim_func @@ -178,7 +178,7 @@ def flattened_unit_loop_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (32), "float32") for x, z in T.grid(4, 8): - C.data[x * 8 + z] = T.load("float32", A.data, x * 8 + z) + 1.0 + C.data[x * 8 + z] = A.data[x * 8 + z] + 1.0 @T.prim_func @@ -205,9 +205,9 @@ def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: for i in range(0, 32): B = T.allocate((32,), "float32", "global") C = T.allocate((32,), "float32", "global") - B[i] = T.load("float32", A.data, i) + 1.0 - C[i] = T.load("float32", A.data, i) + T.load("float32", B, i) - D.data[i] = T.load("float32", C, i) * 2.0 + B[i] = A.data[i] + 1.0 + C[i] = A.data[i] + B[i] + D.data[i] = C[i] * 2.0 @T.prim_func @@ -241,10 +241,10 @@ def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: B_new = T.allocate([68], "float32", "global") for i1 in T.serial(0, 4): for j in T.serial(0, 16): - B_new[i1 * 17 + j] = T.load("float32", A.data, i0 * 64 + i1 * 16 + j) + 1.0 + B_new[i1 * 17 + j] = A.data[i0 * 64 + i1 * 16 + j] + 1.0 for i1 in T.serial(0, 4): for j in T.serial(0, 16): - C.data[i0 * 64 + i1 * 16 + j] = T.load("float32", B_new, i1 * 17 + j) * 2.0 + C.data[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 @T.prim_func diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 8b109172ea09..2422d2ebd9c5 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -545,9 +545,9 @@ def partitioned_concat(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [16], dtype="float32") C = T.match_buffer(c, [32], dtype="float32") for i in T.serial(0, 16): - T.store(C.data, i, T.load("float32", A.data, i), True) + T.store(C.data, i, A.data[i], True) for i in T.serial(0, 16): - T.store(C.data, i + 16, T.load("float32", B.data, i + 16), True) + T.store(C.data, i + 16, B.data[i + 16], True) def test_explicit_partition_hint(): diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py index 1995695100cb..192666d115e4 100644 --- a/tests/python/unittest/test_tir_usmp_algo.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -304,7 +304,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(placeholder_4.data[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5.data[0]), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: @@ -318,15 +318,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65.data[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_7, ff_3, 0, True) for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + T.store(Conv2dOutput_7, ff_3, (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66.data[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))), True) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67.data[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: @@ -341,10 +341,10 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: for ax3_init in T.serial(0, 64): T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -423,7 +423,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): - T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2.data[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3.data[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: @@ -436,15 +436,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): - T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13.data[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): Conv2dOutput_1 = T.allocate([64], "int32", "global") for ff_1 in T.serial(0, 64): T.store(Conv2dOutput_1, ff_1, 0, True) for ry, rx, rc_1 in T.grid(3, 3, 64): - T.store(Conv2dOutput_1, ff_1, T.load("int32", Conv2dOutput_1, ff_1) + T.cast(T.load("int16", PaddedInput_1, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + T.store(Conv2dOutput_1, ff_1, Conv2dOutput_1[ff_1] + T.cast(PaddedInput_1[T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1], "int32") * T.cast(placeholder_14.data[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32"), True) for ax3_inner_2 in T.serial(0, 64): - T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1[ax3_inner_2] + placeholder_15.data[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: @@ -457,16 +457,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): - T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, placeholder_19.data[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2], True) for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): Conv2dOutput_2 = T.allocate([64], "int32", "global") for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): T.store(Conv2dOutput_2, ff_2, 0, True) for rc_2 in T.serial(0, 64): - T.store(Conv2dOutput_2, ff_2, T.load("int32", Conv2dOutput_2, ff_2) + T.cast(T.load("int16", PaddedInput_2, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) + T.store(Conv2dOutput_2, ff_2, Conv2dOutput_2[ff_2] + T.cast(PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20.data[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32"), True) for ax3_inner_3 in T.serial(0, 64): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2[ax3_inner_3] + placeholder_21.data[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: @@ -480,16 +480,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): - T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) + T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, placeholder_29.data[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3], True) for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): Conv2dOutput_3 = T.allocate([64], "int32", "global") for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_3, ff_3, 0, True) for rc_3 in T.serial(0, 64): - T.store(Conv2dOutput_3, ff_3, T.load("int32", Conv2dOutput_3, ff_3) + T.cast(T.load("int16", PaddedInput_3, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) + T.store(Conv2dOutput_3, ff_3, Conv2dOutput_3[ff_3] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27.data[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32"), True) for ax3_inner_4 in T.serial(0, 64): - T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3[ax3_inner_4] + placeholder_26.data[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28.data[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8"), True) @T.prim_func def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: @@ -519,15 +519,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): - T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, placeholder_7.data[i0_i1_fused * 4800 + i2 * 64 + i3], True) for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): Conv2dOutput = T.allocate([64], "int32", "global") for ff in T.serial(0, 64): T.store(Conv2dOutput, ff, 0, True) for rc in T.serial(0, 64): - T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + T.store(Conv2dOutput, ff, Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8.data[rc * 64 + ff], "int32"), True) for ax3_inner_1 in T.serial(0, 64): - T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9.data[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) __tvm_meta__ = None # fmt: on diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index ed8ff329ebf4..32b443ecbc14 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -105,7 +105,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(placeholder_4.data[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5.data[0]), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: @@ -119,15 +119,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65.data[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_7, ff_3, 0, True) for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + T.store(Conv2dOutput_7, ff_3, (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66.data[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))), True) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67.data[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: @@ -142,10 +142,10 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: for ax3_init in T.serial(0, 64): T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -215,7 +215,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): - T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71.data[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_8 in T.parallel(0, 3136): dummy_allocate = T.allocate([1], "int32", "global") for ax3_outer_4 in T.serial(0, 3): @@ -223,9 +223,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol for ff_4 in T.serial(0, 64): T.store(Conv2dOutput_8, ff_4, 0, True) for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): - T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), True) + T.store(Conv2dOutput_8, ff_4, (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72.data[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))), True) for ax3_inner_8 in T.serial(0, 64): - T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73.data[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -256,7 +256,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): - T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71.data[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): dummy_allocate = T.allocate([1], "int32", "global") for ax3_outer_4 in T.serial(0, 3): @@ -264,9 +264,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol for ff_4 in T.serial(0, 64): T.store(Conv2dOutput_8, ff_4, 0, True) for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): - T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), True) + T.store(Conv2dOutput_8, ff_4, (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72.data[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))), True) for ax3_inner_8 in T.serial(0, 64): - T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73.data[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -338,7 +338,7 @@ def tvmgen_default_fused_nn_max_pool2d(placeholder: T.handle, tensor: T.handle) for ax3_outer_init, ax3_inner_init in T.grid(3, 64): T.store(tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer_init*64)) + ax3_inner_init), T.uint8(0), True) for rv0_rv1_fused, ax3_outer, ax3_inner in T.grid(9, 3, 64): - T.store(tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner), T.max(T.load("uint8", tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)), T.if_then_else(((((ax0_ax1_fused*2) + T.floordiv(rv0_rv1_fused, 3)) < 56) and (((ax2*2) + T.floormod(rv0_rv1_fused, 3)) < 56)), T.load("uint8", placeholder_1.data, ((((((ax0_ax1_fused*21504) + (T.floordiv(rv0_rv1_fused, 3)*10752)) + (ax2*384)) + (T.floormod(rv0_rv1_fused, 3)*192)) + (ax3_outer*64)) + ax3_inner)), T.uint8(0), dtype="uint8")), True) + T.store(tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner), T.max(tensor_1.data[((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)], T.if_then_else(((((ax0_ax1_fused*2) + T.floordiv(rv0_rv1_fused, 3)) < 56) and (((ax2*2) + T.floormod(rv0_rv1_fused, 3)) < 56)), placeholder_1.data[((((((ax0_ax1_fused*21504) + (T.floordiv(rv0_rv1_fused, 3)*10752)) + (ax2*384)) + (T.floormod(rv0_rv1_fused, 3)*192)) + (ax3_outer*64)) + ax3_inner)], T.uint8(0), dtype="uint8")), True) @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: @@ -350,7 +350,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(placeholder_4.data[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5.data[0]), True) @T.prim_func def tvmgen_default_fused_cast(placeholder_6: T.handle, T_cast: T.handle) -> None: @@ -361,7 +361,7 @@ def tvmgen_default_fused_cast(placeholder_6: T.handle, T_cast: T.handle) -> None # body for ax0_ax1_fused_2 in T.serial(0, 28): for ax2_2, ax3_outer_1, ax3_inner_2 in T.grid(28, 12, 16): - T.store(T_cast_1.data, ((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2), T.cast(T.load("uint8", placeholder_7.data, ((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)), "int16"), True) + T.store(T_cast_1.data, ((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2), T.cast(placeholder_7.data[((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)], "int16"), True) @T.prim_func def tvmgen_default_fused_concatenate(placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, placeholder_11: T.handle, T_concat: T.handle) -> None: @@ -375,7 +375,7 @@ def tvmgen_default_fused_concatenate(placeholder_8: T.handle, placeholder_9: T.h # body for ax0_ax1_fused_3 in T.serial(0, 28): for ax2_3, ax3 in T.grid(28, 256): - T.store(T_concat_1.data, (((ax0_ax1_fused_3*7168) + (ax2_3*256)) + ax3), T.if_then_else((224 <= ax3), T.load("uint8", placeholder_14.data, ((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 224)), T.if_then_else((192 <= ax3), T.load("uint8", placeholder_15.data, ((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 192)), T.if_then_else((64 <= ax3), T.load("uint8", placeholder_13.data, ((((ax0_ax1_fused_3*3584) + (ax2_3*128)) + ax3) - 64)), T.load("uint8", placeholder_12.data, (((ax0_ax1_fused_3*1792) + (ax2_3*64)) + ax3)), dtype="uint8"), dtype="uint8"), dtype="uint8"), True) + T.store(T_concat_1.data, (((ax0_ax1_fused_3*7168) + (ax2_3*256)) + ax3), T.if_then_else((224 <= ax3), placeholder_14.data[((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 224)], T.if_then_else((192 <= ax3), placeholder_15.data[((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 192)], T.if_then_else((64 <= ax3), placeholder_13.data[((((ax0_ax1_fused_3*3584) + (ax2_3*128)) + ax3) - 64)], placeholder_12.data[(((ax0_ax1_fused_3*1792) + (ax2_3*64)) + ax3)], dtype="uint8"), dtype="uint8"), dtype="uint8"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_cast_2: T.handle) -> None: @@ -389,15 +389,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place PaddedInput = T.allocate([200704], "int16", "global") for i0_i1_fused in T.serial(0, 56): for i2, i3 in T.grid(56, 64): - T.store(PaddedInput, (((i0_i1_fused*3584) + (i2*64)) + i3), T.load("int16", placeholder_19.data, (((i0_i1_fused*3584) + (i2*64)) + i3)), True) + T.store(PaddedInput, (((i0_i1_fused*3584) + (i2*64)) + i3), placeholder_19.data[(((i0_i1_fused*3584) + (i2*64)) + i3)], True) for ax0_ax1_fused_ax2_fused in T.serial(0, 3136): Conv2dOutput = T.allocate([64], "int32", "global") for ff in T.serial(0, 64): T.store(Conv2dOutput, ff, 0, True) for rc in T.serial(0, 64): - T.store(Conv2dOutput, ff, (T.load("int32", Conv2dOutput, ff) + (T.cast(T.load("int16", PaddedInput, ((ax0_ax1_fused_ax2_fused*64) + rc)), "int32")*T.cast(T.load("int16", placeholder_20.data, ((rc*64) + ff)), "int32"))), True) + T.store(Conv2dOutput, ff, (Conv2dOutput[ff] + (T.cast(PaddedInput[((ax0_ax1_fused_ax2_fused*64) + rc)], "int32")*T.cast(placeholder_20.data[((rc*64) + ff)], "int32"))), True) for ax3_inner_3 in T.serial(0, 64): - T.store(T_cast_3.data, ((ax0_ax1_fused_ax2_fused*64) + ax3_inner_3), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_inner_3)), 1191576922, 31, -4, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T.store(T_cast_3.data, ((ax0_ax1_fused_ax2_fused*64) + ax3_inner_3), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput[ax3_inner_3] + placeholder_21.data[ax3_inner_3]), 1191576922, 31, -4, dtype="int32"), 255), 0), "uint8"), "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, T_cast_4: T.handle) -> None: @@ -411,14 +411,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla PaddedInput_1 = T.allocate([150528], "int16", "global") for i0_i1_fused_1 in T.serial(0, 28): for i2_1, i3_1 in T.grid(28, 192): - T.store(PaddedInput_1, (((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1), T.load("int16", placeholder_25.data, (((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)), True) + T.store(PaddedInput_1, (((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1), placeholder_25.data[(((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)], True) for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 784): Conv2dOutput_1 = T.allocate([1], "int32", "global") for ax3_1 in T.serial(0, 96): T.store(Conv2dOutput_1, 0, 0, True) for rc_1 in T.serial(0, 192): - T.store(Conv2dOutput_1, 0, (T.load("int32", Conv2dOutput_1, 0) + (T.cast(T.load("int16", PaddedInput_1, ((ax0_ax1_fused_ax2_fused_1*192) + rc_1)), "int32")*T.cast(T.load("int16", placeholder_26.data, ((rc_1*96) + ax3_1)), "int32"))), True) - T.store(T_cast_5.data, ((ax0_ax1_fused_ax2_fused_1*96) + ax3_1), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_1, 0) + T.load("int32", placeholder_27.data, ax3_1)), 1201322342, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T.store(Conv2dOutput_1, 0, (Conv2dOutput_1[0] + (T.cast(PaddedInput_1[((ax0_ax1_fused_ax2_fused_1*192) + rc_1)], "int32")*T.cast(placeholder_26.data[((rc_1*96) + ax3_1)], "int32"))), True) + T.store(T_cast_5.data, ((ax0_ax1_fused_ax2_fused_1*96) + ax3_1), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_1[0] + placeholder_27.data[ax3_1]), 1201322342, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: @@ -433,10 +433,10 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: for ax3_init in T.serial(0, 64): T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None: @@ -450,15 +450,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2(placehol PaddedInput_2 = T.allocate([150528], "int16", "global") for i0_i1_fused_2 in T.serial(0, 28): for i2_2, i3_2 in T.grid(28, 192): - T.store(PaddedInput_2, (((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2), T.load("int16", placeholder_33.data, (((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)), True) + T.store(PaddedInput_2, (((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2), placeholder_33.data[(((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)], True) for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 784): Conv2dOutput_2 = T.allocate([64], "int32", "global") for ff_1 in T.serial(0, 64): T.store(Conv2dOutput_2, ff_1, 0, True) for rc_2 in T.serial(0, 192): - T.store(Conv2dOutput_2, ff_1, (T.load("int32", Conv2dOutput_2, ff_1) + (T.cast(T.load("int16", PaddedInput_2, ((ax0_ax1_fused_ax2_fused_2*192) + rc_2)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_2*64) + ff_1)), "int32"))), True) + T.store(Conv2dOutput_2, ff_1, (Conv2dOutput_2[ff_1] + (T.cast(PaddedInput_2[((ax0_ax1_fused_ax2_fused_2*192) + rc_2)], "int32")*T.cast(placeholder_34.data[((rc_2*64) + ff_1)], "int32"))), True) for ax3_inner_4 in T.serial(0, 64): - T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_2*64) + ax3_inner_4), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_2, ax3_inner_4) + T.load("int32", placeholder_35.data, ax3_inner_4)), 1663316467, 31, -7, dtype="int32"), 255), 0), "uint8"), True) + T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_2*64) + ax3_inner_4), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_2[ax3_inner_4] + placeholder_35.data[ax3_inner_4]), 1663316467, 31, -7, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast_1(placeholder_36: T.handle, T_cast_10: T.handle) -> None: @@ -473,10 +473,10 @@ def tvmgen_default_fused_nn_max_pool2d_cast_1(placeholder_36: T.handle, T_cast_1 for ax3_outer_init_1, ax3_inner_init_1 in T.grid(3, 64): T.store(tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_init_1*64)) + ax3_inner_init_1), T.uint8(0), True) for rv0_rv1_fused_2, ax3_outer_2, ax3_inner_5 in T.grid(9, 3, 64): - T.store(tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5), T.max(T.load("uint8", tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)), T.if_then_else(((((1 <= (T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6)) and ((T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6) < 29)) and (1 <= (ax2_6 + T.floormod(rv0_rv1_fused_2, 3)))) and ((ax2_6 + T.floormod(rv0_rv1_fused_2, 3)) < 29)), T.load("uint8", placeholder_37.data, (((((((T.floordiv(rv0_rv1_fused_2, 3)*5376) + (ax0_ax1_fused_6*5376)) + (ax2_6*192)) + (T.floormod(rv0_rv1_fused_2, 3)*192)) + (ax3_outer_2*64)) + ax3_inner_5) - 5568)), T.uint8(0), dtype="uint8")), True) + T.store(tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5), T.max(tensor_3[((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)], T.if_then_else(((((1 <= (T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6)) and ((T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6) < 29)) and (1 <= (ax2_6 + T.floormod(rv0_rv1_fused_2, 3)))) and ((ax2_6 + T.floormod(rv0_rv1_fused_2, 3)) < 29)), placeholder_37.data[(((((((T.floordiv(rv0_rv1_fused_2, 3)*5376) + (ax0_ax1_fused_6*5376)) + (ax2_6*192)) + (T.floormod(rv0_rv1_fused_2, 3)*192)) + (ax3_outer_2*64)) + ax3_inner_5) - 5568)], T.uint8(0), dtype="uint8")), True) for ax0_ax1_fused_7 in T.serial(0, 28): for ax2_7, ax3_4 in T.grid(28, 192): - T.store(T_cast_11.data, (((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4), T.cast(T.load("uint8", tensor_3, (((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)), "int16"), True) + T.store(T_cast_11.data, (((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4), T.cast(tensor_3[(((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)], "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2(placeholder_38: T.handle, placeholder_39: T.handle, placeholder_40: T.handle, T_cast_12: T.handle) -> None: @@ -490,14 +490,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed PaddedInput_3 = T.allocate([150528], "int16", "global") for i0_i1_fused_3 in T.serial(0, 28): for i2_3, i3_3 in T.grid(28, 192): - T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_41.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), placeholder_41.data[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)], True) for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 784): Conv2dOutput_3 = T.allocate([1], "int32", "global") for ax3_5 in T.serial(0, 32): T.store(Conv2dOutput_3, 0, 0, True) for rc_3 in T.serial(0, 192): - T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_42.data, ((rc_3*32) + ax3_5)), "int32"))), True) - T.store(T_cast_13.data, ((ax0_ax1_fused_ax2_fused_3*32) + ax3_5), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_43.data, ax3_5)), 1811141736, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1136333842, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + T.store(Conv2dOutput_3, 0, (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_42.data[((rc_3*32) + ax3_5)], "int32"))), True) + T.store(T_cast_13.data, ((ax0_ax1_fused_ax2_fused_3*32) + ax3_5), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_43.data[ax3_5]), 1811141736, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1136333842, 31, 0, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_44: T.handle, placeholder_45: T.handle, placeholder_46: T.handle, T_cast_14: T.handle) -> None: @@ -511,14 +511,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(pla PaddedInput_4 = T.allocate([150528], "int16", "global") for i0_i1_fused_4 in T.serial(0, 28): for i2_4, i3_4 in T.grid(28, 192): - T.store(PaddedInput_4, (((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4), T.load("int16", placeholder_47.data, (((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)), True) + T.store(PaddedInput_4, (((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4), placeholder_47.data[(((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)], True) for ax0_ax1_fused_ax2_fused_4 in T.serial(0, 784): Conv2dOutput_4 = T.allocate([1], "int32", "global") for ax3_6 in T.serial(0, 16): T.store(Conv2dOutput_4, 0, 0, True) for rc_4 in T.serial(0, 192): - T.store(Conv2dOutput_4, 0, (T.load("int32", Conv2dOutput_4, 0) + (T.cast(T.load("int16", PaddedInput_4, ((ax0_ax1_fused_ax2_fused_4*192) + rc_4)), "int32")*T.cast(T.load("int16", placeholder_48.data, ((rc_4*16) + ax3_6)), "int32"))), True) - T.store(T_cast_15.data, ((ax0_ax1_fused_ax2_fused_4*16) + ax3_6), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_4, 0) + T.load("int32", placeholder_49.data, ax3_6)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T.store(Conv2dOutput_4, 0, (Conv2dOutput_4[0] + (T.cast(PaddedInput_4[((ax0_ax1_fused_ax2_fused_4*192) + rc_4)], "int32")*T.cast(placeholder_48.data[((rc_4*16) + ax3_6)], "int32"))), True) + T.store(T_cast_15.data, ((ax0_ax1_fused_ax2_fused_4*16) + ax3_6), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_4[0] + placeholder_49.data[ax3_6]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1(placeholder_50: T.handle, placeholder_51: T.handle, placeholder_52: T.handle, T_cast_16: T.handle) -> None: @@ -532,14 +532,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed PaddedInput_5 = T.allocate([14400], "int16", "global") for i0_i1_fused_5 in T.serial(0, 30): for i2_5, i3_5 in T.grid(30, 16): - T.store(PaddedInput_5, (((i0_i1_fused_5*480) + (i2_5*16)) + i3_5), T.if_then_else(((((1 <= i0_i1_fused_5) and (i0_i1_fused_5 < 29)) and (1 <= i2_5)) and (i2_5 < 29)), T.load("int16", placeholder_53.data, ((((i0_i1_fused_5*448) + (i2_5*16)) + i3_5) - 464)), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_5, (((i0_i1_fused_5*480) + (i2_5*16)) + i3_5), T.if_then_else(((((1 <= i0_i1_fused_5) and (i0_i1_fused_5 < 29)) and (1 <= i2_5)) and (i2_5 < 29)), placeholder_53.data[((((i0_i1_fused_5*448) + (i2_5*16)) + i3_5) - 464)], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_5 in T.serial(0, 784): Conv2dOutput_5 = T.allocate([1], "int32", "global") for ax3_7 in T.serial(0, 32): T.store(Conv2dOutput_5, 0, 0, True) for ry, rx, rc_5 in T.grid(3, 3, 16): - T.store(Conv2dOutput_5, 0, (T.load("int32", Conv2dOutput_5, 0) + (T.cast(T.load("int16", PaddedInput_5, (((((T.floordiv(ax0_ax1_fused_ax2_fused_5, 28)*480) + (ry*480)) + (rx*16)) + (T.floormod(ax0_ax1_fused_ax2_fused_5, 28)*16)) + rc_5)), "int32")*T.cast(T.load("int16", placeholder_54.data, ((((ry*1536) + (rx*512)) + (rc_5*32)) + ax3_7)), "int32"))), True) - T.store(T_cast_17.data, ((ax0_ax1_fused_ax2_fused_5*32) + ax3_7), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_5, 0) + T.load("int32", placeholder_55.data, ax3_7)), 1131968888, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1900719667, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + T.store(Conv2dOutput_5, 0, (Conv2dOutput_5[0] + (T.cast(PaddedInput_5[(((((T.floordiv(ax0_ax1_fused_ax2_fused_5, 28)*480) + (ry*480)) + (rx*16)) + (T.floormod(ax0_ax1_fused_ax2_fused_5, 28)*16)) + rc_5)], "int32")*T.cast(placeholder_54.data[((((ry*1536) + (rx*512)) + (rc_5*32)) + ax3_7)], "int32"))), True) + T.store(T_cast_17.data, ((ax0_ax1_fused_ax2_fused_5*32) + ax3_7), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_5[0] + placeholder_55.data[ax3_7]), 1131968888, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1900719667, 31, 0, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_(placeholder_56: T.handle, placeholder_57: T.handle, placeholder_58: T.handle, T_cast_18: T.handle) -> None: @@ -553,16 +553,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed PaddedInput_6 = T.allocate([86400], "int16", "global") for i0_i1_fused_6 in T.serial(0, 30): for i2_6, i3_6 in T.grid(30, 96): - T.store(PaddedInput_6, (((i0_i1_fused_6*2880) + (i2_6*96)) + i3_6), T.if_then_else(((((1 <= i0_i1_fused_6) and (i0_i1_fused_6 < 29)) and (1 <= i2_6)) and (i2_6 < 29)), T.load("int16", placeholder_59.data, ((((i0_i1_fused_6*2688) + (i2_6*96)) + i3_6) - 2784)), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_6, (((i0_i1_fused_6*2880) + (i2_6*96)) + i3_6), T.if_then_else(((((1 <= i0_i1_fused_6) and (i0_i1_fused_6 < 29)) and (1 <= i2_6)) and (i2_6 < 29)), placeholder_59.data[((((i0_i1_fused_6*2688) + (i2_6*96)) + i3_6) - 2784)], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_6 in T.serial(0, 784): Conv2dOutput_6 = T.allocate([64], "int32", "global") for ax3_outer_3 in T.serial(0, 2): for ff_2 in T.serial(0, 64): T.store(Conv2dOutput_6, ff_2, 0, True) for ry_1, rx_1, rc_6 in T.grid(3, 3, 96): - T.store(Conv2dOutput_6, ff_2, (T.load("int32", Conv2dOutput_6, ff_2) + (T.cast(T.load("int16", PaddedInput_6, (((((T.floordiv(ax0_ax1_fused_ax2_fused_6, 28)*2880) + (ry_1*2880)) + (rx_1*96)) + (T.floormod(ax0_ax1_fused_ax2_fused_6, 28)*96)) + rc_6)), "int32")*T.cast(T.load("int16", placeholder_60.data, (((((ry_1*36864) + (rx_1*12288)) + (rc_6*128)) + (ax3_outer_3*64)) + ff_2)), "int32"))), True) + T.store(Conv2dOutput_6, ff_2, (Conv2dOutput_6[ff_2] + (T.cast(PaddedInput_6[(((((T.floordiv(ax0_ax1_fused_ax2_fused_6, 28)*2880) + (ry_1*2880)) + (rx_1*96)) + (T.floormod(ax0_ax1_fused_ax2_fused_6, 28)*96)) + rc_6)], "int32")*T.cast(placeholder_60.data[(((((ry_1*36864) + (rx_1*12288)) + (rc_6*128)) + (ax3_outer_3*64)) + ff_2)], "int32"))), True) for ax3_inner_6 in T.serial(0, 64): - T.store(T_cast_19.data, (((ax0_ax1_fused_ax2_fused_6*128) + (ax3_outer_3*64)) + ax3_inner_6), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_6, ax3_inner_6) + T.load("int32", placeholder_61.data, ((ax3_outer_3*64) + ax3_inner_6))), 1374050734, 31, -7, dtype="int32"), 255), 0), "uint8"), "int32"), 1544713713, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + T.store(T_cast_19.data, (((ax0_ax1_fused_ax2_fused_6*128) + (ax3_outer_3*64)) + ax3_inner_6), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_6[ax3_inner_6] + placeholder_61.data[((ax3_outer_3*64) + ax3_inner_6)]), 1374050734, 31, -7, dtype="int32"), 255), 0), "uint8"), "int32"), 1544713713, 31, 0, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: @@ -576,15 +576,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65.data[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_7, ff_3, 0, True) for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + T.store(Conv2dOutput_7, ff_3, (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66.data[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))), True) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67.data[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: @@ -598,16 +598,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): - T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71.data[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): Conv2dOutput_8 = T.allocate([64], "int32", "global") for ax3_outer_4 in T.serial(0, 3): for ff_4 in T.serial(0, 64): T.store(Conv2dOutput_8, ff_4, 0, True) for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): - T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), True) + T.store(Conv2dOutput_8, ff_4, (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72.data[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))), True) for ax3_inner_8 in T.serial(0, 64): - T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73.data[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -1111,7 +1111,7 @@ def tvmgen_default_fused_layout_transform_1(placeholder: T.handle, T_layout_tran T_layout_trans_1 = T.match_buffer(T_layout_trans, [1, 1, 24, 12, 3], dtype="float32") # body for ax0_ax1_fused_ax2_fused, ax3, ax4_inner in T.grid(24, 12, 3): - T.store(T_layout_trans_1.data, ax0_ax1_fused_ax2_fused * 36 + ax3 * 3 + ax4_inner, T.load("float32", placeholder_1.data, ax4_inner * 288 + ax0_ax1_fused_ax2_fused * 12 + ax3), True) + T.store(T_layout_trans_1.data, ax0_ax1_fused_ax2_fused * 36 + ax3 * 3 + ax4_inner, placeholder_1.data[ax4_inner * 288 + ax0_ax1_fused_ax2_fused * 12 + ax3], True) @T.prim_func def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeholder_3: T.handle, conv2d_NCHWc: T.handle) -> None: @@ -1123,7 +1123,7 @@ def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeh # body data_pad = T.allocate([1, 1, 26, 14, 3], "float32", "global") for i0_i1_fused_i2_fused, i3, i4 in T.grid(26, 14, 3): - T.store(data_pad, i0_i1_fused_i2_fused * 42 + i3 * 3 + i4, T.if_then_else(1 <= i0_i1_fused_i2_fused and i0_i1_fused_i2_fused < 25 and 1 <= i3 and i3 < 13, T.load("float32", placeholder_4.data, i0_i1_fused_i2_fused * 36 + i3 * 3 + i4 - 39), T.float32(0), dtype="float32"), True) + T.store(data_pad, i0_i1_fused_i2_fused * 42 + i3 * 3 + i4, T.if_then_else(1 <= i0_i1_fused_i2_fused and i0_i1_fused_i2_fused < 25 and 1 <= i3 and i3 < 13, placeholder_4.data[i0_i1_fused_i2_fused * 36 + i3 * 3 + i4 - 39], T.float32(0), dtype="float32"), True) for n_oc_chunk_fused_oh_fused in T.serial(0, 24): conv2d_NCHWc_global = T.allocate([1, 1, 1, 12, 3], "float32", "global") for oc_block_c_init in T.serial(0, 3): @@ -1152,31 +1152,31 @@ def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeh T.store(conv2d_NCHWc_global, oc_block_c_init + 33, T.float32(0), True) for kh, kw, ic_inner in T.grid(3, 3, 3): for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c, T.load("float32", conv2d_NCHWc_global, oc_block_c) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c, conv2d_NCHWc_global[oc_block_c] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 3, T.load("float32", conv2d_NCHWc_global, oc_block_c + 3) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 3) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c + 3, conv2d_NCHWc_global[oc_block_c + 3] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 3] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 6, T.load("float32", conv2d_NCHWc_global, oc_block_c + 6) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 6) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c + 6, conv2d_NCHWc_global[oc_block_c + 6] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 6] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 9, T.load("float32", conv2d_NCHWc_global, oc_block_c + 9) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 9) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c + 9, conv2d_NCHWc_global[oc_block_c + 9] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 9] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 12, T.load("float32", conv2d_NCHWc_global, oc_block_c + 12) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 12) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c + 12, conv2d_NCHWc_global[oc_block_c + 12] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 12] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 15, T.load("float32", conv2d_NCHWc_global, oc_block_c + 15) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 15) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c + 15, conv2d_NCHWc_global[oc_block_c + 15] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 15] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 18, T.load("float32", conv2d_NCHWc_global, oc_block_c + 18) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 18) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c + 18, conv2d_NCHWc_global[oc_block_c + 18] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 18] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 21, T.load("float32", conv2d_NCHWc_global, oc_block_c + 21) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 21) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c + 21, conv2d_NCHWc_global[oc_block_c + 21] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 21] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 24, T.load("float32", conv2d_NCHWc_global, oc_block_c + 24) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 24) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c + 24, conv2d_NCHWc_global[oc_block_c + 24] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 24] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 27, T.load("float32", conv2d_NCHWc_global, oc_block_c + 27) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 27) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c + 27, conv2d_NCHWc_global[oc_block_c + 27] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 27] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 30, T.load("float32", conv2d_NCHWc_global, oc_block_c + 30) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 30) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c + 30, conv2d_NCHWc_global[oc_block_c + 30] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 30] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 33, T.load("float32", conv2d_NCHWc_global, oc_block_c + 33) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 33) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + T.store(conv2d_NCHWc_global, oc_block_c + 33, conv2d_NCHWc_global[oc_block_c + 33] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 33] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) for ow_inner, oc_block in T.grid(12, 3): - T.store(conv2d_NCHWc_1.data, n_oc_chunk_fused_oh_fused * 36 + ow_inner * 3 + oc_block, T.load("float32", conv2d_NCHWc_global, ow_inner * 3 + oc_block), True) + T.store(conv2d_NCHWc_1.data, n_oc_chunk_fused_oh_fused * 36 + ow_inner * 3 + oc_block, conv2d_NCHWc_global[ow_inner * 3 + oc_block], True) @T.prim_func def tvmgen_default_fused_nn_softmax_add_add_multiply_add(placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, T_add: T.handle) -> None: @@ -1194,18 +1194,18 @@ def tvmgen_default_fused_nn_softmax_add_add_multiply_add(placeholder_6: T.handle with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: T.store(T_softmax_maxelem, 0, T.float32(-3.4028234663852886e+38), True) for k in T.serial(0, 12): - T.store(T_softmax_maxelem, 0, T.max(T.load("float32", T_softmax_maxelem, 0), T.load("float32", placeholder_11.data, ax0_ax1_fused_ax2_fused * 12 + k)), True) + T.store(T_softmax_maxelem, 0, T.max(T_softmax_maxelem[0], placeholder_11.data[ax0_ax1_fused_ax2_fused * 12 + k]), True) T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") for i3 in T.serial(0, 12): - T.store(T_softmax_exp, i3, T.exp(T.load("float32", placeholder_11.data, ax0_ax1_fused_ax2_fused * 12 + i3) - T.load("float32", T_softmax_maxelem, 0), dtype="float32"), True) + T.store(T_softmax_exp, i3, T.exp(placeholder_11.data[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32"), True) T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") T.store(T_softmax_expsum, 0, T.float32(0), True) for k in T.serial(0, 12): - T.store(T_softmax_expsum, 0, T.load("float32", T_softmax_expsum, 0) + T.load("float32", T_softmax_exp, k), True) + T.store(T_softmax_expsum, 0, T_softmax_expsum[0] + T_softmax_exp[k], True) for i3 in T.serial(0, 12): - T.store(T_softmax_norm, i3, T.load("float32", T_softmax_exp, i3) / T.load("float32", T_softmax_expsum, 0), True) + T.store(T_softmax_norm, i3, T_softmax_exp[i3] / T_softmax_expsum[0], True) for ax3 in T.serial(0, 12): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused * 12 + ax3, (T.load("float32", placeholder_12.data, ax0_ax1_fused_ax2_fused * 12 + ax3) + T.load("float32", T_softmax_norm, ax3) + T.load("float32", placeholder_13.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24))) * T.load("float32", placeholder_14.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24)) + T.load("float32", placeholder_15.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24)), True) + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused * 12 + ax3, (placeholder_12.data[ax0_ax1_fused_ax2_fused * 12 + ax3] + T_softmax_norm[ax3] + placeholder_13.data[T.floordiv(ax0_ax1_fused_ax2_fused, 24)]) * placeholder_14.data[T.floordiv(ax0_ax1_fused_ax2_fused, 24)] + placeholder_15.data[T.floordiv(ax0_ax1_fused_ax2_fused, 24)], True) @T.prim_func def tvmgen_default_fused_nn_contrib_dense_pack_nn_relu(placeholder_16: T.handle, placeholder_17: T.handle, T_relu: T.handle) -> None: @@ -1236,39 +1236,39 @@ def tvmgen_default_fused_nn_contrib_dense_pack_nn_relu(placeholder_16: T.handle, T.store(compute_global, x_c_init + 42, T.float32(0), True) for k_outer in T.serial(0, 12): for x_c in T.serial(0, 6): - T.store(compute_global, x_c, T.load("float32", compute_global, x_c) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + T.store(compute_global, x_c, compute_global[x_c] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 6, T.load("float32", compute_global, x_c + 6) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 12) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + T.store(compute_global, x_c + 6, compute_global[x_c + 6] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 12] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 12, T.load("float32", compute_global, x_c + 12) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 24) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + T.store(compute_global, x_c + 12, compute_global[x_c + 12] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 24] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 18, T.load("float32", compute_global, x_c + 18) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 36) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + T.store(compute_global, x_c + 18, compute_global[x_c + 18] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 36] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 24, T.load("float32", compute_global, x_c + 24) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 48) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + T.store(compute_global, x_c + 24, compute_global[x_c + 24] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 48] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 30, T.load("float32", compute_global, x_c + 30) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 60) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + T.store(compute_global, x_c + 30, compute_global[x_c + 30] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 60] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 36, T.load("float32", compute_global, x_c + 36) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 72) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + T.store(compute_global, x_c + 36, compute_global[x_c + 36] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 72] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 42, T.load("float32", compute_global, x_c + 42) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 84) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + T.store(compute_global, x_c + 42, compute_global[x_c + 42] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 84] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner, T.load("float32", compute_global, x_inner_inner), True) + T.store(compute, x_inner_inner, compute_global[x_inner_inner], True) for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 6, T.load("float32", compute_global, x_inner_inner + 6), True) + T.store(compute, x_inner_inner + 6, compute_global[x_inner_inner + 6], True) for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 12, T.load("float32", compute_global, x_inner_inner + 12), True) + T.store(compute, x_inner_inner + 12, compute_global[x_inner_inner + 12], True) for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 18, T.load("float32", compute_global, x_inner_inner + 18), True) + T.store(compute, x_inner_inner + 18, compute_global[x_inner_inner + 18], True) for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 24, T.load("float32", compute_global, x_inner_inner + 24), True) + T.store(compute, x_inner_inner + 24, compute_global[x_inner_inner + 24], True) for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 30, T.load("float32", compute_global, x_inner_inner + 30), True) + T.store(compute, x_inner_inner + 30, compute_global[x_inner_inner + 30], True) for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 36, T.load("float32", compute_global, x_inner_inner + 36), True) + T.store(compute, x_inner_inner + 36, compute_global[x_inner_inner + 36], True) for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 42, T.load("float32", compute_global, x_inner_inner + 42), True) + T.store(compute, x_inner_inner + 42, compute_global[x_inner_inner + 42], True) for ax0_inner_inner, ax1_inner_inner in T.grid(8, 6): - T.store(T_relu_1.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + ax0_inner_inner * 12 + T.floordiv(ax1_outer_ax0_outer_fused, 9) * 6 + ax1_inner_inner, T.max(T.load("float32", compute, ax0_inner_inner * 6 + ax1_inner_inner), T.float32(0)), True) + T.store(T_relu_1.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + ax0_inner_inner * 12 + T.floordiv(ax1_outer_ax0_outer_fused, 9) * 6 + ax1_inner_inner, T.max(compute[ax0_inner_inner * 6 + ax1_inner_inner], T.float32(0)), True) @T.prim_func def tvmgen_default_fused_reshape_1(placeholder_20: T.handle, T_reshape: T.handle) -> None: @@ -1278,7 +1278,7 @@ def tvmgen_default_fused_reshape_1(placeholder_20: T.handle, T_reshape: T.handle T_reshape_1 = T.match_buffer(T_reshape, [72, 12], dtype="float32") # body for ax0, ax1_inner in T.grid(72, 12): - T.store(T_reshape_1.data, ax0 * 12 + ax1_inner, T.load("float32", placeholder_21.data, ax0 * 12 + ax1_inner), True) + T.store(T_reshape_1.data, ax0 * 12 + ax1_inner, placeholder_21.data[ax0 * 12 + ax1_inner], True) @T.prim_func def tvmgen_default_fused_layout_transform(placeholder_22: T.handle, T_layout_trans_2: T.handle) -> None: @@ -1288,7 +1288,7 @@ def tvmgen_default_fused_layout_transform(placeholder_22: T.handle, T_layout_tra T_layout_trans_3 = T.match_buffer(T_layout_trans_2, [1, 3, 24, 12], dtype="float32") # body for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): - T.store(T_layout_trans_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, T.load("float32", placeholder_23.data, ax2 * 36 + ax3_inner * 3 + ax0_ax1_fused), True) + T.store(T_layout_trans_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, placeholder_23.data[ax2 * 36 + ax3_inner * 3 + ax0_ax1_fused], True) @T.prim_func def tvmgen_default_fused_reshape(placeholder_24: T.handle, T_reshape_2: T.handle) -> None: @@ -1298,7 +1298,7 @@ def tvmgen_default_fused_reshape(placeholder_24: T.handle, T_reshape_2: T.handle T_reshape_3 = T.match_buffer(T_reshape_2, [1, 3, 24, 12], dtype="float32") # body for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): - T.store(T_reshape_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, T.load("float32", placeholder_25.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner), True) + T.store(T_reshape_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, placeholder_25.data[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner], True) @T.prim_func def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27: T.handle, T_add_2: T.handle) -> None: @@ -1313,18 +1313,18 @@ def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27 with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: T.store(T_softmax_maxelem, 0, T.float32(-3.4028234663852886e+38), True) for k in T.serial(0, 12): - T.store(T_softmax_maxelem, 0, T.max(T.load("float32", T_softmax_maxelem, 0), T.load("float32", placeholder_28.data, ax0_ax1_fused_ax2_fused * 12 + k)), True) + T.store(T_softmax_maxelem, 0, T.max(T_softmax_maxelem[0], placeholder_28.data[ax0_ax1_fused_ax2_fused * 12 + k]), True) T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") for i3 in T.serial(0, 12): - T.store(T_softmax_exp, i3, T.exp(T.load("float32", placeholder_28.data, ax0_ax1_fused_ax2_fused * 12 + i3) - T.load("float32", T_softmax_maxelem, 0), dtype="float32"), True) + T.store(T_softmax_exp, i3, T.exp(placeholder_28.data[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32"), True) T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") T.store(T_softmax_expsum, 0, T.float32(0), True) for k in T.serial(0, 12): - T.store(T_softmax_expsum, 0, T.load("float32", T_softmax_expsum, 0) + T.load("float32", T_softmax_exp, k), True) + T.store(T_softmax_expsum, 0, T_softmax_expsum[0] + T_softmax_exp[k], True) for i3 in T.serial(0, 12): - T.store(T_softmax_norm, i3, T.load("float32", T_softmax_exp, i3) / T.load("float32", T_softmax_expsum, 0), True) + T.store(T_softmax_norm, i3, T_softmax_exp[i3] / T_softmax_expsum[0], True) for ax3 in T.serial(0, 12): - T.store(T_add_3.data, ax0_ax1_fused_ax2_fused * 12 + ax3, T.load("float32", placeholder_29.data, ax0_ax1_fused_ax2_fused * 12 + ax3) + T.load("float32", T_softmax_norm, ax3), True) + T.store(T_add_3.data, ax0_ax1_fused_ax2_fused * 12 + ax3, placeholder_29.data[ax0_ax1_fused_ax2_fused * 12 + ax3] + T_softmax_norm[ax3], True) @T.prim_func def run_model(data: T.handle, output: T.handle) -> None: diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index ab40c646391c..9adfb7639ada 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -79,7 +79,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5.data[0]), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: @@ -93,15 +93,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65.data[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_7, ff_3, 0, True) for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + T.store(Conv2dOutput_7, ff_3, (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66.data[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))), True) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67.data[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: @@ -116,10 +116,10 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: for ax3_init in T.serial(0, 64): T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -146,8 +146,8 @@ def run_model(input: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_9_let: T.handle = T.address_of(T.load("uint8", slow_memory_1_buffer_var.data, 1117472), dtype="handle") - sid_8_let: T.handle = T.address_of(T.load("uint8", slow_memory_1_buffer_var.data, 0), dtype="handle") + sid_9_let: T.handle = T.address_of(slow_memory_1_buffer_var.data[1117472], dtype="handle") + sid_8_let: T.handle = T.address_of(slow_memory_1_buffer_var.data[0], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) @@ -159,14 +159,14 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - tensor_2_let: T.handle = T.address_of(T.load("uint8", fast_memory_6_buffer_var.data, 0), dtype="handle") + tensor_2_let: T.handle = T.address_of(fast_memory_6_buffer_var.data[0], dtype="handle") for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56): for ax3_init in T.serial(0, 64): T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init, T.uint8(0), True) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2, T.max(T.load("uint8", tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2), T.if_then_else(ax0_ax1_fused_4 * 2 + rv0_rv1_fused_1 // 3 < 112 and ax2_4 * 2 + rv0_rv1_fused_1 % 3 < 112, T.load("uint8", placeholder_29.data, ax0_ax1_fused_4 * 14336 + rv0_rv1_fused_1 // 3 * 7168 + ax2_4 * 128 + rv0_rv1_fused_1 % 3 * 64 + ax3_2), T.uint8(0), dtype="uint8")), True) + T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2, T.max(tensor_2_let[ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2], T.if_then_else(ax0_ax1_fused_4 * 2 + rv0_rv1_fused_1 // 3 < 112 and ax2_4 * 2 + rv0_rv1_fused_1 % 3 < 112, placeholder_29.data[ax0_ax1_fused_4 * 14336 + rv0_rv1_fused_1 // 3 * 7168 + ax2_4 * 128 + rv0_rv1_fused_1 % 3 * 64 + ax3_2], T.uint8(0), dtype="uint8")), True) for ax0_ax1_fused_5, ax2_5, ax3_3 in T.grid(56, 56, 64): - T.store(T_cast_7.data, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3, T.cast(T.load("uint8", tensor_2_let, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3), "int16"), True) + T.store(T_cast_7.data, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3, T.cast(tensor_2_let[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3], "int16"), True) @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle, slow_memory_3_var: T.handle) -> None: @@ -177,7 +177,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): - T.store(T_subtract_1.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1, T.cast(T.load("uint8", placeholder_4.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1), "int16") - T.load("int16", placeholder_5.data, 0), True) + T.store(T_subtract_1.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1, T.cast(placeholder_4.data[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5.data[0], True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle, slow_memory_5_var: T.handle) -> None: @@ -188,17 +188,17 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_7_let: T.handle = T.address_of(T.load("uint8", slow_memory_5_buffer_var.data, 802816), dtype="handle") + PaddedInput_7_let: T.handle = T.address_of(slow_memory_5_buffer_var.data[802816], dtype="handle") for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3): - T.store(PaddedInput_7_let, i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7, T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, T.load("int16", placeholder_65.data, i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_7_let, i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7, T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, placeholder_65.data[i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): - Conv2dOutput_7_let: T.handle = T.address_of(T.load("uint8", fast_memory_4_buffer_var.data, 0), dtype="handle") + Conv2dOutput_7_let: T.handle = T.address_of(fast_memory_4_buffer_var.data[0], dtype="handle") for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_7_let, ff_3, 0, True) for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7_let, ff_3, T.load("int32", Conv2dOutput_7_let, ff_3) + T.cast(T.load("int16", PaddedInput_7_let, ax0_ax1_fused_ax2_fused_7 // 112 * 1374 + ry_2 * 687 + ax0_ax1_fused_ax2_fused_7 % 112 * 6 + rx_2 * 3 + rc_7), "int32") * T.cast(T.load("int16", placeholder_66.data, ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3), "int32"), True) + T.store(Conv2dOutput_7_let, ff_3, Conv2dOutput_7_let[ff_3] + T.cast(PaddedInput_7_let[ax0_ax1_fused_ax2_fused_7 // 112 * 1374 + ry_2 * 687 + ax0_ax1_fused_ax2_fused_7 % 112 * 6 + rx_2 * 3 + rc_7], "int32") * T.cast(placeholder_66.data[ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3], "int32"), True) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7, T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_7_let, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T.store(T_cast_21.data, ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7, T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_7_let[ax3_inner_7] + placeholder_67.data[ax3_inner_7], 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) # fmt: on @@ -259,7 +259,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): - T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2.data[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3.data[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: @@ -272,15 +272,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): - T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13.data[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): Conv2dOutput_1 = T.allocate([64], "int32", "global") for ff_1 in T.serial(0, 64): T.store(Conv2dOutput_1, ff_1, 0, True) for ry, rx, rc_1 in T.grid(3, 3, 64): - T.store(Conv2dOutput_1, ff_1, T.load("int32", Conv2dOutput_1, ff_1) + T.cast(T.load("int16", PaddedInput_1, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + T.store(Conv2dOutput_1, ff_1, Conv2dOutput_1[ff_1] + T.cast(PaddedInput_1[T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1], "int32") * T.cast(placeholder_14.data[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32"), True) for ax3_inner_2 in T.serial(0, 64): - T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1[ax3_inner_2] + placeholder_15.data[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: @@ -293,16 +293,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): - T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, placeholder_19.data[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2], True) for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): Conv2dOutput_2 = T.allocate([64], "int32", "global") for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): T.store(Conv2dOutput_2, ff_2, 0, True) for rc_2 in T.serial(0, 64): - T.store(Conv2dOutput_2, ff_2, T.load("int32", Conv2dOutput_2, ff_2) + T.cast(T.load("int16", PaddedInput_2, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) + T.store(Conv2dOutput_2, ff_2, Conv2dOutput_2[ff_2] + T.cast(PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20.data[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32"), True) for ax3_inner_3 in T.serial(0, 64): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2[ax3_inner_3] + placeholder_21.data[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: @@ -316,16 +316,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): - T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) + T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, placeholder_29.data[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3], True) for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): Conv2dOutput_3 = T.allocate([64], "int32", "global") for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_3, ff_3, 0, True) for rc_3 in T.serial(0, 64): - T.store(Conv2dOutput_3, ff_3, T.load("int32", Conv2dOutput_3, ff_3) + T.cast(T.load("int16", PaddedInput_3, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) + T.store(Conv2dOutput_3, ff_3, Conv2dOutput_3[ff_3] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27.data[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32"), True) for ax3_inner_4 in T.serial(0, 64): - T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3[ax3_inner_4] + placeholder_26.data[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28.data[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8"), True) @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -355,15 +355,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): - T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, placeholder_7.data[i0_i1_fused * 4800 + i2 * 64 + i3], True) for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): Conv2dOutput = T.allocate([64], "int32", "global") for ff in T.serial(0, 64): T.store(Conv2dOutput, ff, 0, True) for rc in T.serial(0, 64): - T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + T.store(Conv2dOutput, ff, Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8.data[rc * 64 + ff], "int32"), True) for ax3_inner_1 in T.serial(0, 64): - T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9.data[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) # fmt: on @@ -378,7 +378,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): - T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2.data[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3.data[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None: @@ -389,18 +389,18 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle") + PaddedInput_3_let: T.handle = T.address_of(global_workspace_5_buffer_var.data[6480000], dtype="handle") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): - T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) + T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, placeholder_29.data[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3], True) for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): - Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 7200000), dtype="handle") + Conv2dOutput_3_let: T.handle = T.address_of(global_workspace_5_buffer_var.data[7200000], dtype="handle") for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_3_let, ff_3, 0, True) for rc_3 in T.serial(0, 64): - T.store(Conv2dOutput_3_let, ff_3, T.load("int32", Conv2dOutput_3_let, ff_3) + T.cast(T.load("int16", PaddedInput_3_let, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) + T.store(Conv2dOutput_3_let, ff_3, Conv2dOutput_3_let[ff_3] + T.cast(PaddedInput_3_let[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27.data[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32"), True) for ax3_inner_4 in T.serial(0, 64): - T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3_let, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3_let[ax3_inner_4] + placeholder_26.data[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28.data[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle) -> None: @@ -410,18 +410,18 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7200000), dtype="handle") + PaddedInput_2_let: T.handle = T.address_of(global_workspace_4_buffer_var.data[7200000], dtype="handle") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): - T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, placeholder_19.data[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2], True) for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): - Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7920000), dtype="handle") + Conv2dOutput_2_let: T.handle = T.address_of(global_workspace_4_buffer_var.data[7920000], dtype="handle") for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): T.store(Conv2dOutput_2_let, ff_2, 0, True) for rc_2 in T.serial(0, 64): - T.store(Conv2dOutput_2_let, ff_2, T.load("int32", Conv2dOutput_2_let, ff_2) + T.cast(T.load("int16", PaddedInput_2_let, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) + T.store(Conv2dOutput_2_let, ff_2, Conv2dOutput_2_let[ff_2] + T.cast(PaddedInput_2_let[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20.data[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32"), True) for ax3_inner_3 in T.serial(0, 64): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2_let, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2_let[ax3_inner_3] + placeholder_21.data[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None: @@ -431,17 +431,17 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7200000), dtype="handle") + PaddedInput_let: T.handle = T.address_of(global_workspace_2_buffer_var.data[7200000], dtype="handle") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): - T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, placeholder_7.data[i0_i1_fused * 4800 + i2 * 64 + i3], True) for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): - Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7920000), dtype="handle") + Conv2dOutput_let: T.handle = T.address_of(global_workspace_2_buffer_var.data[7920000], dtype="handle") for ff in T.serial(0, 64): T.store(Conv2dOutput_let, ff, 0, True) for rc in T.serial(0, 64): - T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + T.store(Conv2dOutput_let, ff, Conv2dOutput_let[ff] + T.cast(PaddedInput_let[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8.data[rc * 64 + ff], "int32"), True) for ax3_inner_1 in T.serial(0, 64): - T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_let[ax3_inner_1] + placeholder_9.data[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None: @@ -451,17 +451,17 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 0), dtype="handle") + PaddedInput_1_let: T.handle = T.address_of(global_workspace_3_buffer_var.data[0], dtype="handle") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): - T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13.data[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): - Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 7200000), dtype="handle") + Conv2dOutput_1_let: T.handle = T.address_of(global_workspace_3_buffer_var.data[7200000], dtype="handle") for ff_1 in T.serial(0, 64): T.store(Conv2dOutput_1_let, ff_1, 0, True) for ry, rx, rc_1 in T.grid(3, 3, 64): - T.store(Conv2dOutput_1_let, ff_1, T.load("int32", Conv2dOutput_1_let, ff_1) + T.cast(T.load("int16", PaddedInput_1_let, ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + T.store(Conv2dOutput_1_let, ff_1, Conv2dOutput_1_let[ff_1] + T.cast(PaddedInput_1_let[ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1], "int32") * T.cast(placeholder_14.data[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32"), True) for ax3_inner_2 in T.serial(0, 64): - T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1_let[ax3_inner_2] + placeholder_15.data[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) @T.prim_func def run_model(input: T.handle, global_workspace_0_var: T.handle, output: T.handle) -> None: @@ -469,10 +469,10 @@ def run_model(input: T.handle, global_workspace_0_var: T.handle, output: T.handl # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 5760000), dtype="handle") - sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle") - sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle") - sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle") + sid_2_let: T.handle = T.address_of(global_workspace_0_buffer_var.data[5760000], dtype="handle") + sid_6_let: T.handle = T.address_of(global_workspace_0_buffer_var.data[0], dtype="handle") + sid_7_let: T.handle = T.address_of(global_workspace_0_buffer_var.data[6480000], dtype="handle") + sid_8_let: T.handle = T.address_of(global_workspace_0_buffer_var.data[6480000], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32")) diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py index 34e526ae5173..3d90687118b3 100644 --- a/tests/python/unittest/test_tir_usmp_utils.py +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -37,7 +37,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(placeholder_4.data[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5.data[0]), True) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: @@ -51,15 +51,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65.data[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16"), True) for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): T.store(Conv2dOutput_7, ff_3, 0, True) for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + T.store(Conv2dOutput_7, ff_3, (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66.data[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))), True) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67.data[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: @@ -74,10 +74,10 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: for ax3_init in T.serial(0, 64): T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) @T.prim_func def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 19dc81290e16..bc66f6ddd90f 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -333,7 +333,7 @@ def opaque_access_during_complete(a: T.handle) -> None: # error for i, j in T.grid(16, 16): with T.block(): vi, vj = T.axis.remap("SS", [i, j]) - T.evaluate(T.load("float32", A.data, vi * 16 + vj)) + T.evaluate(A[vi * 16 + vj]) def test_opaque_access_during_complete(): @@ -415,7 +415,7 @@ def intrin_except_unassign(a: T.handle) -> None: def intrin_except_assign(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") - A[0, 0] = T.load(A, A, A) # error + A[0, 0] = A[A] # error def test_tvm_exception_catch(): diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 51a4ce7960a8..da685a949527 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -99,12 +99,10 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: T.store( packedB, T.ramp(((x * 32768) + (y * 32)), 1, 32), - T.load( - "float32x32", - B_1.data, + B_1.data[ T.ramp(((y * 1024) + (x * 32)), 1, 32), T.broadcast(True, 32), - ), + ], T.broadcast(True, 32), ) for x_outer in T.parallel(0, 32): @@ -123,27 +121,21 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: C_global, T.ramp((x_c * 32), 1, 32), ( - T.load( - "float32x32", - C_global, + C_global[ T.ramp((x_c * 32), 1, 32), T.broadcast(True, 32), - ) + ] + ( T.broadcast( - T.load( - "float32", - A_1.data, + A_1.data[ (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)), - ), + ], 32, ) - * T.load( - "float32x32", - packedB, + * packedB[ T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), T.broadcast(True, 32), - ) + ] ) ), T.broadcast(True, 32), @@ -152,30 +144,24 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: C_global, T.ramp((x_c * 32), 1, 32), ( - T.load( - "float32x32", - C_global, + C_global[ T.ramp((x_c * 32), 1, 32), T.broadcast(True, 32), - ) + ] + ( T.broadcast( - T.load( - "float32", - A_1.data, + A_1.data[ ( (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 1 ), - ), + ], 32, ) - * T.load( - "float32x32", - packedB, + * packedB[ T.ramp((((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32), T.broadcast(True, 32), - ) + ] ) ), T.broadcast(True, 32), @@ -184,30 +170,24 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: C_global, T.ramp((x_c * 32), 1, 32), ( - T.load( - "float32x32", - C_global, + C_global[ T.ramp((x_c * 32), 1, 32), T.broadcast(True, 32), - ) + ] + ( T.broadcast( - T.load( - "float32", - A_1.data, + A_1.data[ ( (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 2 ), - ), + ], 32, ) - * T.load( - "float32x32", - packedB, + * packedB[ T.ramp((((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32), T.broadcast(True, 32), - ) + ] ) ), T.broadcast(True, 32), @@ -216,30 +196,24 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: C_global, T.ramp((x_c * 32), 1, 32), ( - T.load( - "float32x32", - C_global, + C_global[ T.ramp((x_c * 32), 1, 32), T.broadcast(True, 32), - ) + ] + ( T.broadcast( - T.load( - "float32", - A_1.data, + A_1.data[ ( (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 3 ), - ), + ], 32, ) - * T.load( - "float32x32", - packedB, + * packedB[ T.ramp((((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32), T.broadcast(True, 32), - ) + ] ) ), T.broadcast(True, 32), @@ -248,7 +222,7 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: for y_inner in T.serial(0, 32): C_1.data[ ((((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) + y_inner) - ] = T.load("float32", C_global, ((x_inner * 32) + y_inner)) + ] = C_global[((x_inner * 32) + y_inner)] def test_opt_gemm_lower(): @@ -282,11 +256,11 @@ def mmult( # body assert num_args == 3, "mmult: num_args should be 3" arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") - arg0_code: T.int32 = T.load("int32", arg_type_ids, 0) + arg0_code: T.int32 = arg_type_ids[0] arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") - arg1_code: T.int32 = T.load("int32", arg_type_ids, 1) + arg1_code: T.int32 = arg_type_ids[1] arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") - arg2_code: T.int32 = T.load("int32", arg_type_ids, 2) + arg2_code: T.int32 = arg_type_ids[2] A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") T.attr(A, "storage_alignment", 128) arg0_shape: T.handle = T.tvm_struct_get(arg0, 0, 2, dtype="handle") @@ -318,14 +292,14 @@ def mmult( T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) ), "arg0.dtype is expected to be float32" assert 1024 == T.cast( - T.load("int64", arg0_shape, 0), "int32" + arg0_shape[0], "int32" ), "Argument arg0.shape[0] has an unsatisfied constraint" assert 1024 == T.cast( - T.load("int64", arg0_shape, 1), "int32" + arg0_shape[1], "int32" ), "Argument arg0.shape[1] has an unsatisfied constraint" if not (T.isnullptr(arg0_strides, dtype="bool")): - assert (1 == T.cast(T.load("int64", arg0_strides, 1), "int32")) and ( - 1024 == T.cast(T.load("int64", arg0_strides, 0), "int32") + assert (1 == T.cast(arg0_strides[1], "int32")) and ( + 1024 == T.cast(arg0_strides[0], "int32") ), "arg0.strides: expected to be compact array" T.evaluate(0) assert T.uint64(0) == T.tvm_struct_get( @@ -343,14 +317,14 @@ def mmult( T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) ), "arg1.dtype is expected to be float32" assert 1024 == T.cast( - T.load("int64", arg1_shape, 0), "int32" + arg1_shape[0], "int32" ), "Argument arg1.shape[0] has an unsatisfied constraint" assert 1024 == T.cast( - T.load("int64", arg1_shape, 1), "int32" + arg1_shape[1], "int32" ), "Argument arg1.shape[1] has an unsatisfied constraint" if not (T.isnullptr(arg1_strides, dtype="bool")): - assert (1 == T.cast(T.load("int64", arg1_strides, 1), "int32")) and ( - 1024 == T.cast(T.load("int64", arg1_strides, 0), "int32") + assert (1 == T.cast(arg1_strides[1], "int32")) and ( + 1024 == T.cast(arg1_strides[0], "int32") ), "arg1.strides: expected to be compact array" T.evaluate(0) assert T.uint64(0) == T.tvm_struct_get( @@ -371,14 +345,14 @@ def mmult( T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) ), "arg2.dtype is expected to be float32" assert 1024 == T.cast( - T.load("int64", arg2_shape, 0), "int32" + arg2_shape[0], "int32" ), "Argument arg2.shape[0] has an unsatisfied constraint" assert 1024 == T.cast( - T.load("int64", arg2_shape, 1), "int32" + arg2_shape[1], "int32" ), "Argument arg2.shape[1] has an unsatisfied constraint" if not (T.isnullptr(arg2_strides, dtype="bool")): - assert (1 == T.cast(T.load("int64", arg2_strides, 1), "int32")) and ( - 1024 == T.cast(T.load("int64", arg2_strides, 0), "int32") + assert (1 == T.cast(arg2_strides[1], "int32")) and ( + 1024 == T.cast(arg2_strides[0], "int32") ), "arg2.strides: expected to be compact array" T.evaluate(0) assert T.uint64(0) == T.tvm_struct_get( @@ -404,12 +378,10 @@ def mmult( T.store( packedB, T.ramp(((x * 32768) + (y * 32)), 1, 32), - T.load( - "float32x32", - B, + B[ T.ramp(((y * 1024) + (x * 32)), 1, 32), T.broadcast(True, 32), - ), + ], T.broadcast(True, 32), ) for x_outer in T.parallel(0, 32): @@ -438,28 +410,22 @@ def mmult( T.uint32(97), T.uint32(3), T.broadcast( - T.load( - "float32", - A, + A[ ( ((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4) ), - ), + ], 32, ), - T.load( - "float32x32", - packedB, + packedB[ T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), T.broadcast(True, 32), - ), - T.load( - "float32x32", - C_global, + ], + C_global[ T.ramp((x_c * 32), 1, 32), T.broadcast(True, 32), - ), + ], dtype="float32x32", ), T.broadcast(True, 32), @@ -471,9 +437,7 @@ def mmult( T.uint32(97), T.uint32(3), T.broadcast( - T.load( - "float32", - A, + A[ ( ( ((x_outer * 32768) + (x_c * 1024)) @@ -481,23 +445,19 @@ def mmult( ) + 1 ), - ), + ], 32, ), - T.load( - "float32x32", - packedB, + packedB[ T.ramp( (((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32 ), T.broadcast(True, 32), - ), - T.load( - "float32x32", - C_global, + ], + C_global[ T.ramp((x_c * 32), 1, 32), T.broadcast(True, 32), - ), + ], dtype="float32x32", ), T.broadcast(True, 32), @@ -509,9 +469,7 @@ def mmult( T.uint32(97), T.uint32(3), T.broadcast( - T.load( - "float32", - A, + A[ ( ( ((x_outer * 32768) + (x_c * 1024)) @@ -519,23 +477,19 @@ def mmult( ) + 2 ), - ), + ], 32, ), - T.load( - "float32x32", - packedB, + packedB[ T.ramp( (((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32 ), T.broadcast(True, 32), - ), - T.load( - "float32x32", - C_global, + ], + C_global[ T.ramp((x_c * 32), 1, 32), T.broadcast(True, 32), - ), + ], dtype="float32x32", ), T.broadcast(True, 32), @@ -547,9 +501,7 @@ def mmult( T.uint32(97), T.uint32(3), T.broadcast( - T.load( - "float32", - A, + A[ ( ( ((x_outer * 32768) + (x_c * 1024)) @@ -557,23 +509,19 @@ def mmult( ) + 3 ), - ), + ], 32, ), - T.load( - "float32x32", - packedB, + packedB[ T.ramp( (((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32 ), T.broadcast(True, 32), - ), - T.load( - "float32x32", - C_global, + ], + C_global[ T.ramp((x_c * 32), 1, 32), T.broadcast(True, 32), - ), + ], dtype="float32x32", ), T.broadcast(True, 32), @@ -585,7 +533,7 @@ def mmult( (((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) + y_inner ) - ] = T.load("float32", C_global, ((x_inner * 32) + y_inner)) + ] = C_global[((x_inner * 32) + y_inner)] if T.TVMBackendFreeWorkspace(1, dev_id, C_global, dtype="int32") != 0: T.evaluate(T.tvm_throw_last_error(dtype="int32")) if T.TVMBackendFreeWorkspace(1, dev_id, packedB, dtype="int32") != 0: @@ -1130,9 +1078,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1155,7 +1101,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61440 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1173,9 +1119,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1198,7 +1142,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61408 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1216,9 +1160,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1241,7 +1183,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61376 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1259,9 +1201,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1284,7 +1224,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61344 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1302,9 +1242,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1327,7 +1265,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61312 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1345,9 +1283,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1370,7 +1306,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61280 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1388,9 +1324,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1413,7 +1347,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61248 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1431,9 +1365,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1456,7 +1388,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61216 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1474,9 +1406,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1499,7 +1429,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61184 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1517,9 +1447,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1542,7 +1470,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61152 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1560,9 +1488,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1585,7 +1511,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61120 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1603,9 +1529,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1628,7 +1552,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61088 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1646,9 +1570,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1671,7 +1593,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61056 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1689,9 +1611,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1714,7 +1634,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 61024 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1732,9 +1652,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1757,7 +1675,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 60992 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1772,9 +1690,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - T.load( - "float16", - A_1.data, + A_1.data[ ( ( ( @@ -1794,7 +1710,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) - 60960 ), - ), + ], T.float16(0), dtype="float16", ) @@ -1802,9 +1718,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: T.store( W_shared, T.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8), - T.load( - "float16x8", - W_1.data, + W_1.data[ T.ramp( ( ( @@ -1820,16 +1734,14 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: 8, ), T.broadcast(True, 8), - ), + ], T.broadcast(True, 8), ) with T.launch_thread(tx, 32): T.store( W_shared, T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 2048), 1, 8), - T.load( - "float16x8", - W_1.data, + W_1.data[ T.ramp( ( ( @@ -1848,16 +1760,14 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: 8, ), T.broadcast(True, 8), - ), + ], T.broadcast(True, 8), ) with T.launch_thread(tx, 32): T.store( W_shared, T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 4096), 1, 8), - T.load( - "float16x8", - W_1.data, + W_1.data[ T.ramp( ( ( @@ -1876,16 +1786,14 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: 8, ), T.broadcast(True, 8), - ), + ], T.broadcast(True, 8), ) with T.launch_thread(tx, 32): T.store( W_shared, T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 6144), 1, 8), - T.load( - "float16x8", - W_1.data, + W_1.data[ T.ramp( ( ( @@ -1904,16 +1812,14 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: 8, ), T.broadcast(True, 8), - ), + ], T.broadcast(True, 8), ) with T.launch_thread(tx, 32): T.store( W_shared, T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 8192), 1, 8), - T.load( - "float16x8", - W_1.data, + W_1.data[ T.ramp( ( ( @@ -1932,16 +1838,14 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: 8, ), T.broadcast(True, 8), - ), + ], T.broadcast(True, 8), ) with T.launch_thread(tx, 32): T.store( W_shared, T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 10240), 1, 8), - T.load( - "float16x8", - W_1.data, + W_1.data[ T.ramp( ( ( @@ -1960,7 +1864,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: 8, ), T.broadcast(True, 8), - ), + ], T.broadcast(True, 8), ) for ic_inner in T.serial(0, 2): @@ -2422,11 +2326,11 @@ def opt_conv_tensorcore_mod_host( stack_value: T.handle = T.tvm_stack_alloca("arg_value", 10, dtype="handle") assert num_args == 3, "default_function: num_args should be 3" arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") - arg0_code: T.int32 = T.load("int32", arg_type_ids, 0) + arg0_code: T.int32 = arg_type_ids[0] arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") - arg1_code: T.int32 = T.load("int32", arg_type_ids, 1) + arg1_code: T.int32 = arg_type_ids[1] arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") - arg2_code: T.int32 = T.load("int32", arg_type_ids, 2) + arg2_code: T.int32 = arg_type_ids[2] A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") T.attr(A, "storage_alignment", 128) arg0_shape: T.handle = T.tvm_struct_get(arg0, 0, 2, dtype="handle") @@ -2458,38 +2362,38 @@ def opt_conv_tensorcore_mod_host( T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) ), "arg0.dtype is expected to be float16" assert 16 == T.cast( - T.load("int64", arg0_shape, 0), "int32" + arg0_shape[0], "int32" ), "Argument arg0.shape[0] has an unsatisfied constraint" assert 14 == T.cast( - T.load("int64", arg0_shape, 1), "int32" + arg0_shape[1], "int32" ), "Argument arg0.shape[1] has an unsatisfied constraint" assert 14 == T.cast( - T.load("int64", arg0_shape, 2), "int32" + arg0_shape[2], "int32" ), "Argument arg0.shape[2] has an unsatisfied constraint" assert 16 == T.cast( - T.load("int64", arg0_shape, 3), "int32" + arg0_shape[3], "int32" ), "Argument arg0.shape[3] has an unsatisfied constraint" assert 16 == T.cast( - T.load("int64", arg0_shape, 4), "int32" + arg0_shape[4], "int32" ), "Argument arg0.shape[4] has an unsatisfied constraint" assert 16 == T.cast( - T.load("int64", arg0_shape, 5), "int32" + arg0_shape[5], "int32" ), "Argument arg0.shape[5] has an unsatisfied constraint" if not (T.isnullptr(arg0_strides, dtype="bool")): assert ( ( ( ( - (1 == T.cast(T.load("int64", arg0_strides, 5), "int32")) - and (16 == T.cast(T.load("int64", arg0_strides, 4), "int32")) + (1 == T.cast(arg0_strides[5], "int32")) + and (16 == T.cast(arg0_strides[4], "int32")) ) - and (256 == T.cast(T.load("int64", arg0_strides, 3), "int32")) + and (256 == T.cast(arg0_strides[3], "int32")) ) - and (4096 == T.cast(T.load("int64", arg0_strides, 2), "int32")) + and (4096 == T.cast(arg0_strides[2], "int32")) ) - and (57344 == T.cast(T.load("int64", arg0_strides, 1), "int32")) + and (57344 == T.cast(arg0_strides[1], "int32")) ) and ( - 802816 == T.cast(T.load("int64", arg0_strides, 0), "int32") + 802816 == T.cast(arg0_strides[0], "int32") ), "arg0.strides: expected to be compact array" T.evaluate(0) assert T.uint64(0) == T.tvm_struct_get( @@ -2507,38 +2411,38 @@ def opt_conv_tensorcore_mod_host( T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) ), "arg1.dtype is expected to be float16" assert 3 == T.cast( - T.load("int64", arg1_shape, 0), "int32" + arg1_shape[0], "int32" ), "Argument arg1.shape[0] has an unsatisfied constraint" assert 3 == T.cast( - T.load("int64", arg1_shape, 1), "int32" + arg1_shape[1], "int32" ), "Argument arg1.shape[1] has an unsatisfied constraint" assert 16 == T.cast( - T.load("int64", arg1_shape, 2), "int32" + arg1_shape[2], "int32" ), "Argument arg1.shape[2] has an unsatisfied constraint" assert 32 == T.cast( - T.load("int64", arg1_shape, 3), "int32" + arg1_shape[3], "int32" ), "Argument arg1.shape[3] has an unsatisfied constraint" assert 16 == T.cast( - T.load("int64", arg1_shape, 4), "int32" + arg1_shape[4], "int32" ), "Argument arg1.shape[4] has an unsatisfied constraint" assert 16 == T.cast( - T.load("int64", arg1_shape, 5), "int32" + arg1_shape[5], "int32" ), "Argument arg1.shape[5] has an unsatisfied constraint" if not (T.isnullptr(arg1_strides, dtype="bool")): assert ( ( ( ( - (1 == T.cast(T.load("int64", arg1_strides, 5), "int32")) - and (16 == T.cast(T.load("int64", arg1_strides, 4), "int32")) + (1 == T.cast(arg1_strides[5], "int32")) + and (16 == T.cast(arg1_strides[4], "int32")) ) - and (256 == T.cast(T.load("int64", arg1_strides, 3), "int32")) + and (256 == T.cast(arg1_strides[3], "int32")) ) - and (8192 == T.cast(T.load("int64", arg1_strides, 2), "int32")) + and (8192 == T.cast(arg1_strides[2], "int32")) ) - and (131072 == T.cast(T.load("int64", arg1_strides, 1), "int32")) + and (131072 == T.cast(arg1_strides[1], "int32")) ) and ( - 393216 == T.cast(T.load("int64", arg1_strides, 0), "int32") + 393216 == T.cast(arg1_strides[0], "int32") ), "arg1.strides: expected to be compact array" T.evaluate(0) assert T.uint64(0) == T.tvm_struct_get( @@ -2559,38 +2463,38 @@ def opt_conv_tensorcore_mod_host( T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) ), "arg2.dtype is expected to be float32" assert 16 == T.cast( - T.load("int64", arg2_shape, 0), "int32" + arg2_shape[0], "int32" ), "Argument arg2.shape[0] has an unsatisfied constraint" assert 14 == T.cast( - T.load("int64", arg2_shape, 1), "int32" + arg2_shape[1], "int32" ), "Argument arg2.shape[1] has an unsatisfied constraint" assert 14 == T.cast( - T.load("int64", arg2_shape, 2), "int32" + arg2_shape[2], "int32" ), "Argument arg2.shape[2] has an unsatisfied constraint" assert 32 == T.cast( - T.load("int64", arg2_shape, 3), "int32" + arg2_shape[3], "int32" ), "Argument arg2.shape[3] has an unsatisfied constraint" assert 16 == T.cast( - T.load("int64", arg2_shape, 4), "int32" + arg2_shape[4], "int32" ), "Argument arg2.shape[4] has an unsatisfied constraint" assert 16 == T.cast( - T.load("int64", arg2_shape, 5), "int32" + arg2_shape[5], "int32" ), "Argument arg2.shape[5] has an unsatisfied constraint" if not (T.isnullptr(arg2_strides, dtype="bool")): assert ( ( ( ( - (1 == T.cast(T.load("int64", arg2_strides, 5), "int32")) - and (16 == T.cast(T.load("int64", arg2_strides, 4), "int32")) + (1 == T.cast(arg2_strides[5], "int32")) + and (16 == T.cast(arg2_strides[4], "int32")) ) - and (256 == T.cast(T.load("int64", arg2_strides, 3), "int32")) + and (256 == T.cast(arg2_strides[3], "int32")) ) - and (8192 == T.cast(T.load("int64", arg2_strides, 2), "int32")) + and (8192 == T.cast(arg2_strides[2], "int32")) ) - and (114688 == T.cast(T.load("int64", arg2_strides, 1), "int32")) + and (114688 == T.cast(arg2_strides[1], "int32")) ) and ( - 1605632 == T.cast(T.load("int64", arg2_strides, 0), "int32") + 1605632 == T.cast(arg2_strides[0], "int32") ), "arg2.strides: expected to be compact array" T.evaluate(0) assert T.uint64(0) == T.tvm_struct_get( @@ -2655,9 +2559,9 @@ def vthread_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i2, 2) B = T.allocate([16], "float32", "local") for j in range(16): - B[j] = T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + T.float32(1) + B[j] = A.data[i0 * 64 + i1 * 32 + i2 * 16 + j] + T.float32(1) for j in range(16): - C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = T.load("float32", B, j) * T.float32(2) + C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * T.float32(2) def test_vthread(): @@ -2936,7 +2840,7 @@ def test_rank0_buffers(): def rank0_block(a: T.handle) -> None: A = T.match_buffer(a, (), "float32") B = T.alloc_buffer((), "float32") - T.store(B.data, 0, T.load("float32", A.data, 0)) + T.store(B.data, 0, A.data[0]) with T.block("update") as []: T.reads([A[()]]) @@ -3082,10 +2986,10 @@ def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.han for ax3_init in T.serial(0, 64): T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) # fmt: on @@ -3105,7 +3009,7 @@ def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_x, 128) reduce_temp0 = T.allocate([1], "float32", "local") with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): - T.evaluate(T.tvm_thread_allreduce(T.uint32(1), T.load("float32", A.data, i * 128 + threadIdx_x), True, reduce_temp0, threadIdx_x, dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A.data[i * 128 + threadIdx_x], True, reduce_temp0, threadIdx_x, dtype="handle")) @T.prim_func @@ -3117,7 +3021,7 @@ def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_x, 128) reduce_temp0 = T.allocate([1], "float32", "local") with T.attr(T.comm_reducer(lambda x0, x1, y0, y1: (T.Select((x1 >= y1), x0, y0), T.Select((x1 >= y1), x1, y1)), [T.int32(-1), T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): - T.evaluate(T.tvm_thread_allreduce(T.uint32(1), T.load("float32", A.data, i * 128 + threadIdx_x), True, reduce_temp0, threadIdx_x, dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A.data[i * 128 + threadIdx_x], True, reduce_temp0, threadIdx_x, dtype="handle")) @T.prim_func From 417ee2ba8a8edc4a4bebbe16289254a9510d042e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 12 Jan 2022 15:36:50 -0600 Subject: [PATCH 035/177] Return buffer object from tvm.tir.script.scope_handler.Allocate Now that the load/store require buffer objects, allocation should also return a buffer object to be used. --- python/tvm/script/tir/__init__.pyi | 2 +- python/tvm/script/tir/scope_handler.py | 31 ++++++--- src/printer/tvmscript_printer.cc | 67 ++++++++++++++----- .../unittest/test_tvmscript_roundtrip.py | 4 +- 4 files changed, 73 insertions(+), 31 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index ac4ee3018f7c..0593236512a1 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -311,7 +311,7 @@ def allocate( scope: str, condition: Union[PrimExpr, builtins.bool] = True, annotations: Optional[Mapping[str, Object]] = None, -) -> Var: ... +) -> Buffer: ... def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ... def realize( buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index fc953771bf21..a22dc4f8afe8 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -111,10 +111,16 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None) condition = tvm.runtime.convert(condition) scope = tvm.runtime.convert(scope) + # Currently, allocate nodes should only occur after buffer + # flattening has been applied. This can be simplified in + # the future by having the AllocateNode hold a buffer + # object directly. + flattened = self.buffer.get_flattened_buffer() + return tvm.tir.Allocate( - self.buffer_var, - dtype, - extents, + self.buffer.data, + flattened.dtype, + flattened.shape, condition, self.body, annotations=annotations, @@ -122,7 +128,7 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None) ) super().__init__(allocate, concise_scope=True, def_symbol=True) - self.buffer_var = None + self.buffer = None def enter_scope( self, @@ -146,15 +152,20 @@ def enter_scope( else: raise Exception("Internal Bug") - def setup_buffer_var( + def setup_buffer( extents, dtype, scope, condition=True, annotations=None, span: Span = None ): - """Setup buffer var for a given type.""" - buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope) - self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) + """Setup buffer object for a given type.""" + self.buffer = tvm.tir.decl_buffer( + shape=extents, + dtype=dtype, + name=name, + scope=scope, + span=span, + ) - setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) - context.update_symbol(name, self.buffer_var, node) + setup_buffer(*arg_list, span=tvm_span_from_synr(var_span)) + context.update_symbol(name, self.buffer, node) @register diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 0d6c6e5deeba..c4b48e98de73 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -22,6 +22,7 @@ * \brief Printer class to print Tensor IR to python syntax script */ +#include #include #include #include @@ -261,7 +262,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintRange(const RangeNode* op); Doc PrintArray(const ArrayNode* op); Doc PrintBuffer(const BufferNode* op); - Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body); + Doc PrintNonHeaderBufferDeclarations(const Array& aliasing_buffers); Doc AllocBufferDeclaration(const Buffer& buf); Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); Doc PrintBlockVarRemaps(); @@ -888,16 +889,21 @@ Doc TVMScriptPrinter::VisitExpr_(const ReduceNode* op, ExprPrecedence* out_prece } Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { + if (!buffer_var_usage_.count(op->var)) { + buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body); + } + Array buffer_usage = buffer_var_usage_.Get(op->var).value_or({}); + Doc doc; if (current_num_ != num_child_ - 1) { doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):"; - doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body) - << PrintBody(op->body)); + doc << Doc::Indent( + 4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body)); } else { if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get()); doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value) << Doc::NewLine(); - doc << PrintNonHeaderBufferDeclarations(op->var, op->body) << PrintBody(op->body); + doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body); } return doc; } @@ -985,7 +991,39 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { } Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { - var_not_in_headers_.insert(op->buffer_var.get()); + auto is_exact_match = [](Buffer a, Buffer b) { + if (a->dtype != b->dtype) return false; + if (a->shape.size() != b->shape.size()) return false; + + arith::Analyzer analyzer; + for (size_t i = 0; i < a->shape.size(); i++) { + if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) { + return false; + } + } + return true; + }; + + if (!buffer_var_usage_.count(op->buffer_var)) { + buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body); + } + Array buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({}); + + // If the buffer allocated via T.allocate is an exact match to the + // usage of the buffer later on, then that buffer is the return + // value of T.allocate, and no T.buffer_decl statement is needed. + Buffer alloc_buf(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0, 0, + kDefault); + Array aliasing_buffers; + for (const auto& buf : buffer_usage) { + if (is_exact_match(buf, alloc_buf)) { + alloc_buf = buf; + } else { + aliasing_buffers.push_back(buf); + } + } + buf_not_in_headers_.insert(alloc_buf.get()); + var_not_in_headers_.insert(alloc_buf->data.get()); auto storage_scope = GetPtrStorageScope(op->buffer_var); Doc func_call; @@ -1003,13 +1041,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; if (current_num_ != num_child_ - 1) { - doc << "with " << func_call << " as " << Print(op->buffer_var) << ":"; - doc << Doc::Indent(4, Doc::NewLine() - << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) - << PrintBody(op->body)); + doc << "with " << func_call << " as " << Print(alloc_buf) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers) + << PrintBody(op->body)); } else { - doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine(); - doc << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body); + doc << Print(alloc_buf) << " = " << func_call << Doc::NewLine(); + doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(op->body); } TryDeallocVar(op->buffer_var); return doc; @@ -1518,13 +1555,9 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) { return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer); } -Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body) { - if (!buffer_var_usage_.count(buffer_var)) { - buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), body); - } - Array buffer_usage = buffer_var_usage_.Get(buffer_var).value_or({}); +Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(const Array& aliasing_buffers) { Doc decls; - for (const auto& buf_usage : buffer_usage) { + for (const auto& buf_usage : aliasing_buffers) { decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl(" << memo_buf_decl_[buf_usage] << ")" << Doc::NewLine(); buf_not_in_headers_.insert(buf_usage.get()); diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index da685a949527..35a63a58e89c 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3194,9 +3194,7 @@ def test_T_ptr_let_statement(): @T.prim_func def func_T_ptr_allocate() -> None: - A_data: T.Ptr[T.float32] = T.allocate([1024], "float32", "global") - A = T.buffer_decl([1024], dtype="float32", data=A_data) - + A = T.allocate([1024], "float32", "global") A[0] = 0.0 From e80914b33eb7d1254a3dcf744c4f61605882ff16 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 12 Jan 2022 16:06:25 -0600 Subject: [PATCH 036/177] Added .astype to tvm.script.tir.node.BufferSlice Since `buf[i]` returns a `BufferSlice`, this lets the TIR examples that use `buf[i].astype('out_dtype')` continue functioning. --- python/tvm/script/tir/node.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py index cfbc668946a0..11aa411f6b9d 100644 --- a/python/tvm/script/tir/node.py +++ b/python/tvm/script/tir/node.py @@ -152,3 +152,6 @@ def asobject(self) -> BufferLoad: indices = [s.start for s in self.slices] return BufferLoad(self.buffer, indices, span=self.span) + + def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr: + return self.asobject().astype(dtype) From d87a5611e2d2e6a483963f570ce1ce181f8d114c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 13 Jan 2022 18:59:56 -0600 Subject: [PATCH 037/177] Replacing all T.store TIR calls. --- python/tvm/script/tir/node.py | 3 +- .../unittest/test_target_codegen_llvm.py | 2 +- ...t_tir_analysis_detect_buffer_access_lca.py | 4 +- tests/python/unittest/test_tir_intrin.py | 4 +- .../unittest/test_tir_lower_match_buffer.py | 2 +- .../test_tir_schedule_cache_read_write.py | 6 +- .../test_tir_schedule_compute_inline.py | 6 +- .../unittest/test_tir_schedule_reorder.py | 4 +- .../unittest/test_tir_schedule_split_fuse.py | 6 +- ...est_tir_transform_compact_buffer_region.py | 14 +- ..._tir_transform_convert_for_loops_serial.py | 8 +- .../test_tir_transform_flatten_buffer.py | 26 +- .../test_tir_transform_loop_partition.py | 4 +- tests/python/unittest/test_tir_usmp_algo.py | 50 +- ...st_tir_usmp_analysis_extract_bufferinfo.py | 264 ++++---- ...orm_convert_pool_allocations_to_offsets.py | 223 +++---- tests/python/unittest/test_tir_usmp_utils.py | 22 +- .../unittest/test_tvmscript_roundtrip.py | 589 +++++++----------- 18 files changed, 549 insertions(+), 688 deletions(-) diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py index 11aa411f6b9d..42c8754ded95 100644 --- a/python/tvm/script/tir/node.py +++ b/python/tvm/script/tir/node.py @@ -96,7 +96,8 @@ def check_index(index: Union[int, PrimExpr]): if index < 0: report_error("Negative index is not allowed during buffer access", span) elif isinstance(index, PrimExpr): - if index.dtype != "int32": + element_dtype = index.dtype.split("x", maxsplit=1)[0] + if element_dtype != "int32": report_error( "index expected an int32 type PrimExpr but got " + str(index.dtype), index.span, diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index c09161d67ce6..c2a7326d517a 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -925,7 +925,7 @@ def threadpool_nested_parallel_loop( T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in T.parallel(4): for j in T.parallel(4): - T.store(B.data, i * 4 + j, A.data[i * 4 + j] * 2.0) + B.data[i * 4 + j] = A.data[i * 4 + j] * 2.0 with pytest.raises(tvm.TVMError) as e: tvm.build({"llvm": tvm.IRModule.from_expr(threadpool_nested_parallel_loop)}) diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 6645983f5211..9b688bb857f2 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -54,7 +54,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: T.writes(B[0:16, 0:16]) A = T.allocate([256], "float32", "global") for i, j in T.grid(16, 16): - T.store(A, i * 16 + j, 1) + A[i * 16 + j] = 1 for i in range(0, 16): for j in range(0, 16): T.evaluate(A[i * 16 + j]) @@ -70,7 +70,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: @T.prim_func def lca_is_func_root(a: T.handle) -> None: A = T.match_buffer(a, [0, 0], "float32") - A.data[0] = 1.0 + A[0] = 1.0 @T.prim_func diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 444cc9121e77..b800a6d2109c 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -236,9 +236,7 @@ def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None: ) # body for i in T.serial(0, n): - d_1.data[(i * stride_3)] = ( - A_1.data[(i * stride)] * B_1.data[(i * stride_1)] - ) + C_1.data[(i * stride_2)] + d_1[(i * stride_3)] = (A_1[(i * stride)] * B_1[(i * stride_1)]) + C_1[(i * stride_2)] def test_fma(): diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 0c93f3a50382..3a9af20a41b5 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -476,7 +476,7 @@ def fail_match_store(a: T.handle) -> None: T.reads([]) T.writes(A[i, j]) sub_A = T.match_buffer(A[i, j], ()) - sub_A.data[0] = 1 + sub_A[0] = 1 @T.prim_func diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index 00bcb710e24b..7feb82a095fe 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -80,7 +80,7 @@ def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = A.data[vi * 128 + vj] + D[vi * 128 + vj] = A[vi * 128 + vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) @@ -272,7 +272,7 @@ def cache_read_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) vi, vj = T.axis.remap("SS", [i, j]) T.reads(A_global[vi, vj]) T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = A_global.data[vi * 128 + vj] + D[vi * 128 + vj] = A_global[vi * 128 + vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) @@ -481,7 +481,7 @@ def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(D_global[vi, vj]) - D_global.data[vi * 128 + vj] = A.data[vi * 128 + vj] + D_global[vi * 128 + vj] = A[vi * 128 + vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index a098b2322792..3e31e9dd0d73 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -183,7 +183,7 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) - C[vi, vj] = B.data[vi * 128 + vj] + 1.0 + C[vi, vj] = B[vi * 128 + vj] + 1.0 @T.prim_func @@ -200,8 +200,8 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) - T.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) - C[vi, vj] = B.data[vi * 16 + vj] + 1.0 + C[vi * 128 + vj] = B[vi, vj] + 1.0 + C[vi, vj] = B[vi * 16 + vj] + 1.0 @T.prim_func diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index fd2d82d1ff1f..bfa469b86fbe 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -153,7 +153,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) + A[vi * 16 + vj] = 1 for i, j in T.grid(16, 16): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) @@ -171,7 +171,7 @@ def opaque_access_reorder(a: T.handle, b: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) + A[vi * 16 + vj] = 1 for j, i in T.grid(16, 16): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 84ececebbcba..576a8a99ef69 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -273,7 +273,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) + A[vi * 16 + vj] = 1 for i, j in T.grid(16, 16): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) @@ -292,7 +292,7 @@ def opaque_access_fused(a: T.handle, b: T.handle) -> None: vj = T.axis.S(16, T.floormod(i_j_fused, 16)) T.reads([]) T.writes([A[0:16, 0:16]]) - T.store(A.data, ((vi * 16) + vj), 1, 1) + A[((vi * 16) + vj)] = 1 for i_j_fused in T.serial(0, 256): with T.block("B"): vi = T.axis.S(16, T.floordiv(i_j_fused, 16)) @@ -312,7 +312,7 @@ def opaque_access_split(a: T.handle, b: T.handle) -> None: vj = T.axis.S(16, j0 * 4 + j1) T.reads([]) T.writes([A[0:16, 0:16]]) - T.store(A.data, ((vi * 16) + vj), 1, 1) + A[((vi * 16) + vj)] = 1 for i, j0, j1 in T.grid(16, 4, 4): with T.block("B"): vi = T.axis.S(16, i) diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 145b9af9eddd..9a3799ba5f46 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -80,7 +80,7 @@ def unschedulable_func(a: T.handle, c: T.handle) -> None: T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - T.store(B.data, i * 16 + j, A[i, j] + 1.0) + B[i * 16 + j] = A[i, j]y + 1.0 for j in range(0, 16): C[i, j] = B[i, j] * 2.0 @@ -251,7 +251,7 @@ def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(4, 8): D[k, j] = 1.0 for k in range(2, 4): - T.store(B.data, j, A[i, j] + D[k, j]) + B[j] = A[i, j] + Dy[k, j] for j in range(3, 5): with T.block() as []: T.reads(B[i, j]) @@ -281,7 +281,7 @@ def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(4, 8): D[k - 2, 0] = 1.0 for k in range(2, 4): - T.store(B.data, j, A[i, j] + D[k - 2, 0]) + B[j] = A[i, j] + D[k -y 2, 0] for j in range(3, 5): with T.block() as []: T.reads(B[0, j]) @@ -476,13 +476,13 @@ def opaque_access_annotated_func(a: T.handle) -> None: # no annotation, opaque access will cover full region T.reads([]) T.writes([]) - T.store(B.data, i, "float32", A[i]) + B[i] = A[i] with T.block(): # treat opaque access only access annotated regions, even if # they are not compatible with actual buffer accesses. T.reads([B[i]]) T.writes([C[i : i + 9]]) - T.store(C.data, i, B.data[i]) + C[i] = B[i] @T.prim_func @@ -496,13 +496,13 @@ def compacted_opaque_access_annotated_func(a: T.handle) -> None: # no annotation, opaque access will cover full region T.reads([]) T.writes([]) - T.store(B.data, i, "float32", A[i]) + B[i] = A[i] with T.block(): # treat opaque access only access annotated regions, even if # they are not compatible with actual buffer accesses. T.reads([B[i]]) T.writes([C[i : i + 9]]) - T.store(C.data, i, B.data[i]) + C[i] = B[i] def test_elementwise(): diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py index 98b894fbf733..862fb31ee40a 100644 --- a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -34,14 +34,14 @@ def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T. PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global") for i0_i1_fused_3 in T.parallel(0, 28): for i2_3, i3_3 in T.grid(28, 192): - T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), placeholder_33.data[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)], True) + PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3) ] = placeholder_33[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)] for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784): for ax3_2 in T.serial(0, 16): Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global") - T.store(Conv2dOutput_3, 0, 0, True) + Conv2dOutput_3[0] = 0 for rc_3 in T.serial(0, 192): - T.store(Conv2dOutput_3, 0, (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34.data[((rc_3*16) + ax3_2)], "int32"))), True) - T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35.data[ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34[((rc_3*16) + ax3_2)], "int32"))), True + T_cast_9[((ax0_ax1_fused_ax2_fused_3*16) + ax3_2)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35[ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True # fmt: on diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index be8c3c1f656f..7dab0589dd9e 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -55,9 +55,9 @@ def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: for i in T.serial(0, 16): B_new = T.allocate([16], "float32", "global") for j in T.serial(0, 16): - B_new[j] = A.data[((i * 16) + j)] + 1.0 + B_new[j] = A[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): - C.data[((i * 16) + j)] = B_new[j] * 2.0 + C[((i * 16) + j)] = B_new[j] * 2.0 @T.prim_func @@ -97,9 +97,9 @@ def flattened_gpu_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i2, 2) B = T.allocate([16], "float32", "local") for j in range(0, 16): - B[j] = A.data[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 + B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): - C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 + C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 @T.prim_func @@ -132,9 +132,9 @@ def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> for i in range(0, n): B = T.allocate([m], "float32", "global") for j in range(0, m): - B[j] = A.data[i * m + j] + 1.0 + B[j] = A[i * m + j] + 1.0 for j in range(0, m): - C.data[i * m + j] = B[j] * 2.0 + C[i * m + j] = B[j] * 2.0 @T.prim_func @@ -157,7 +157,7 @@ def flattened_predicate_func(a: T.handle, c: T.handle) -> None: for i, j in T.grid(5, 7): if i * 7 + j < 32: - C.data[i * 7 + j] = A.data[i * 7 + j] + 1.0 + C[i * 7 + j] = A[i * 7 + j] + 1.0 @T.prim_func @@ -178,7 +178,7 @@ def flattened_unit_loop_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (32), "float32") for x, z in T.grid(4, 8): - C.data[x * 8 + z] = A.data[x * 8 + z] + 1.0 + C[x * 8 + z] = A[x * 8 + z] + 1.0 @T.prim_func @@ -205,9 +205,9 @@ def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: for i in range(0, 32): B = T.allocate((32,), "float32", "global") C = T.allocate((32,), "float32", "global") - B[i] = A.data[i] + 1.0 - C[i] = A.data[i] + B[i] - D.data[i] = C[i] * 2.0 + B[i] = A[i] + 1.0 + C[i] = A[i] + B[i] + D[i] = C[i] * 2.0 @T.prim_func @@ -241,10 +241,10 @@ def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: B_new = T.allocate([68], "float32", "global") for i1 in T.serial(0, 4): for j in T.serial(0, 16): - B_new[i1 * 17 + j] = A.data[i0 * 64 + i1 * 16 + j] + 1.0 + B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 for i1 in T.serial(0, 4): for j in T.serial(0, 16): - C.data[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 + C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 @T.prim_func diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 2422d2ebd9c5..53011975ca21 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -545,9 +545,9 @@ def partitioned_concat(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [16], dtype="float32") C = T.match_buffer(c, [32], dtype="float32") for i in T.serial(0, 16): - T.store(C.data, i, A.data[i], True) + C[i] = A[i] for i in T.serial(0, 16): - T.store(C.data, i + 16, B.data[i + 16], True) + C[i + 16] = B[i + 16] def test_explicit_partition_hint(): diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py index 192666d115e4..96def166ea43 100644 --- a/tests/python/unittest/test_tir_usmp_algo.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -304,7 +304,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(placeholder_4.data[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5.data[0]), True) + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: @@ -318,15 +318,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65.data[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16"), True) + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7, ff_3, 0, True) + Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66.data[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))), True) + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67.data[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: @@ -339,12 +339,12 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -423,7 +423,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): - T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2.data[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3.data[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16"), True) + T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: @@ -436,15 +436,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): - T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13.data[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16"), True) + PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): Conv2dOutput_1 = T.allocate([64], "int32", "global") for ff_1 in T.serial(0, 64): - T.store(Conv2dOutput_1, ff_1, 0, True) + Conv2dOutput_1[ff_1] = 0 for ry, rx, rc_1 in T.grid(3, 3, 64): - T.store(Conv2dOutput_1, ff_1, Conv2dOutput_1[ff_1] + T.cast(PaddedInput_1[T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1], "int32") * T.cast(placeholder_14.data[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32"), True) + Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(PaddedInput_1[T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1], "int32") * T.cast(placeholder_14[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32") for ax3_inner_2 in T.serial(0, 64): - T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1[ax3_inner_2] + placeholder_15.data[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1[ax3_inner_2] + placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: @@ -457,16 +457,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): - T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, placeholder_19.data[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2], True) + PaddedInput_2[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): Conv2dOutput_2 = T.allocate([64], "int32", "global") for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): - T.store(Conv2dOutput_2, ff_2, 0, True) + Conv2dOutput_2[ff_2] = 0 for rc_2 in T.serial(0, 64): - T.store(Conv2dOutput_2, ff_2, Conv2dOutput_2[ff_2] + T.cast(PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20.data[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32"), True) + Conv2dOutput_2[ff_2] = Conv2dOutput_2[ff_2] + T.cast(PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32") for ax3_inner_3 in T.serial(0, 64): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2[ax3_inner_3] + placeholder_21.data[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136 @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: @@ -480,16 +480,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): - T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, placeholder_29.data[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3], True) + PaddedInput_3[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): Conv2dOutput_3 = T.allocate([64], "int32", "global") for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_3, ff_3, 0, True) + Conv2dOutput_3[ff_3] = 0 for rc_3 in T.serial(0, 64): - T.store(Conv2dOutput_3, ff_3, Conv2dOutput_3[ff_3] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27.data[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32"), True) + Conv2dOutput_3[ff_3] = Conv2dOutput_3[ff_3] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32") for ax3_inner_4 in T.serial(0, 64): - T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3[ax3_inner_4] + placeholder_26.data[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28.data[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8"), True) + T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8") @T.prim_func def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: @@ -519,15 +519,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): - T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, placeholder_7.data[i0_i1_fused * 4800 + i2 * 64 + i3], True) + PaddedInput[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3] for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): Conv2dOutput = T.allocate([64], "int32", "global") for ff in T.serial(0, 64): - T.store(Conv2dOutput, ff, 0, True) + Conv2dOutput[ff] = 0 for rc in T.serial(0, 64): - T.store(Conv2dOutput, ff, Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8.data[rc * 64 + ff], "int32"), True) + Conv2dOutput[ff] = Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32") for ax3_inner_1 in T.serial(0, 64): - T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9.data[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") __tvm_meta__ = None # fmt: on diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index 32b443ecbc14..0d39e2eecf50 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -105,7 +105,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(placeholder_4.data[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5.data[0]), True) + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: @@ -119,15 +119,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65.data[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16"), True) + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7, ff_3, 0, True) + Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66.data[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))), True) + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67.data[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: @@ -140,12 +140,12 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -215,17 +215,17 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): - T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71.data[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16"), True) + PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_8 in T.parallel(0, 3136): dummy_allocate = T.allocate([1], "int32", "global") for ax3_outer_4 in T.serial(0, 3): Conv2dOutput_8 = T.allocate([64], "int32", "global") for ff_4 in T.serial(0, 64): - T.store(Conv2dOutput_8, ff_4, 0, True) + Conv2dOutput_8[ff_4] = 0 for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): - T.store(Conv2dOutput_8, ff_4, (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72.data[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))), True) + Conv2dOutput_8[ff_4] = (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))) for ax3_inner_8 in T.serial(0, 64): - T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73.data[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + T_cast_23[(((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -256,17 +256,17 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): - T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71.data[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16"), True) + PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): dummy_allocate = T.allocate([1], "int32", "global") for ax3_outer_4 in T.serial(0, 3): Conv2dOutput_8 = T.allocate([64], "int32", "global") for ff_4 in T.serial(0, 64): - T.store(Conv2dOutput_8, ff_4, 0, True) + Conv2dOutput_8[ff_4] = 0 for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): - T.store(Conv2dOutput_8, ff_4, (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72.data[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))), True) + Conv2dOutput_8[ff_4] = (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))) for ax3_inner_8 in T.serial(0, 64): - T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73.data[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + T_cast_23[(((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -336,9 +336,9 @@ def tvmgen_default_fused_nn_max_pool2d(placeholder: T.handle, tensor: T.handle) for ax0_ax1_fused in T.serial(0, 28): for ax2 in T.serial(0, 28): for ax3_outer_init, ax3_inner_init in T.grid(3, 64): - T.store(tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer_init*64)) + ax3_inner_init), T.uint8(0), True) + tensor_1[((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer_init*64)) + ax3_inner_init)] = T.uint8(0) for rv0_rv1_fused, ax3_outer, ax3_inner in T.grid(9, 3, 64): - T.store(tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner), T.max(tensor_1.data[((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)], T.if_then_else(((((ax0_ax1_fused*2) + T.floordiv(rv0_rv1_fused, 3)) < 56) and (((ax2*2) + T.floormod(rv0_rv1_fused, 3)) < 56)), placeholder_1.data[((((((ax0_ax1_fused*21504) + (T.floordiv(rv0_rv1_fused, 3)*10752)) + (ax2*384)) + (T.floormod(rv0_rv1_fused, 3)*192)) + (ax3_outer*64)) + ax3_inner)], T.uint8(0), dtype="uint8")), True) + tensor_1[((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)] = T.max(tensor_1[((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)], T.if_then_else(((((ax0_ax1_fused*2) + T.floordiv(rv0_rv1_fused, 3)) < 56) and (((ax2*2) + T.floormod(rv0_rv1_fused, 3)) < 56)), placeholder_1[((((((ax0_ax1_fused*21504) + (T.floordiv(rv0_rv1_fused, 3)*10752)) + (ax2*384)) + (T.floormod(rv0_rv1_fused, 3)*192)) + (ax3_outer*64)) + ax3_inner)], T.uint8(0), dtype="uint8")) @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: @@ -350,7 +350,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(placeholder_4.data[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5.data[0]), True) + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) @T.prim_func def tvmgen_default_fused_cast(placeholder_6: T.handle, T_cast: T.handle) -> None: @@ -361,7 +361,7 @@ def tvmgen_default_fused_cast(placeholder_6: T.handle, T_cast: T.handle) -> None # body for ax0_ax1_fused_2 in T.serial(0, 28): for ax2_2, ax3_outer_1, ax3_inner_2 in T.grid(28, 12, 16): - T.store(T_cast_1.data, ((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2), T.cast(placeholder_7.data[((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)], "int16"), True) + T_cast_1[((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)] = T.cast(placeholder_7[((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)], "int16") @T.prim_func def tvmgen_default_fused_concatenate(placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, placeholder_11: T.handle, T_concat: T.handle) -> None: @@ -375,7 +375,7 @@ def tvmgen_default_fused_concatenate(placeholder_8: T.handle, placeholder_9: T.h # body for ax0_ax1_fused_3 in T.serial(0, 28): for ax2_3, ax3 in T.grid(28, 256): - T.store(T_concat_1.data, (((ax0_ax1_fused_3*7168) + (ax2_3*256)) + ax3), T.if_then_else((224 <= ax3), placeholder_14.data[((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 224)], T.if_then_else((192 <= ax3), placeholder_15.data[((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 192)], T.if_then_else((64 <= ax3), placeholder_13.data[((((ax0_ax1_fused_3*3584) + (ax2_3*128)) + ax3) - 64)], placeholder_12.data[(((ax0_ax1_fused_3*1792) + (ax2_3*64)) + ax3)], dtype="uint8"), dtype="uint8"), dtype="uint8"), True) + T_concat_1[(((ax0_ax1_fused_3*7168) + (ax2_3*256)) + ax3)] = T.if_then_else((224 <= ax3), placeholder_14[((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 224)], T.if_then_else((192 <= ax3), placeholder_15[((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 192)], T.if_then_else((64 <= ax3), placeholder_13[((((ax0_ax1_fused_3*3584) + (ax2_3*128)) + ax3) - 64)], placeholder_12[(((ax0_ax1_fused_3*1792) + (ax2_3*64)) + ax3)], dtype="uint8"), dtype="uint8"), dtype="uint8") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_cast_2: T.handle) -> None: @@ -389,15 +389,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place PaddedInput = T.allocate([200704], "int16", "global") for i0_i1_fused in T.serial(0, 56): for i2, i3 in T.grid(56, 64): - T.store(PaddedInput, (((i0_i1_fused*3584) + (i2*64)) + i3), placeholder_19.data[(((i0_i1_fused*3584) + (i2*64)) + i3)], True) + PaddedInput[(((i0_i1_fused*3584) + (i2*64)) + i3)] = placeholder_19[(((i0_i1_fused*3584) + (i2*64)) + i3)] for ax0_ax1_fused_ax2_fused in T.serial(0, 3136): Conv2dOutput = T.allocate([64], "int32", "global") for ff in T.serial(0, 64): - T.store(Conv2dOutput, ff, 0, True) + Conv2dOutput[ff] = 0 for rc in T.serial(0, 64): - T.store(Conv2dOutput, ff, (Conv2dOutput[ff] + (T.cast(PaddedInput[((ax0_ax1_fused_ax2_fused*64) + rc)], "int32")*T.cast(placeholder_20.data[((rc*64) + ff)], "int32"))), True) + Conv2dOutput[ff] = (Conv2dOutput[ff] + (T.cast(PaddedInput[((ax0_ax1_fused_ax2_fused*64) + rc)], "int32")*T.cast(placeholder_20[((rc*64) + ff)], "int32"))) for ax3_inner_3 in T.serial(0, 64): - T.store(T_cast_3.data, ((ax0_ax1_fused_ax2_fused*64) + ax3_inner_3), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput[ax3_inner_3] + placeholder_21.data[ax3_inner_3]), 1191576922, 31, -4, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T_cast_3[((ax0_ax1_fused_ax2_fused*64) + ax3_inner_3)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput[ax3_inner_3] + placeholder_21[ax3_inner_3]), 1191576922, 31, -4, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, T_cast_4: T.handle) -> None: @@ -411,14 +411,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla PaddedInput_1 = T.allocate([150528], "int16", "global") for i0_i1_fused_1 in T.serial(0, 28): for i2_1, i3_1 in T.grid(28, 192): - T.store(PaddedInput_1, (((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1), placeholder_25.data[(((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)], True) + PaddedInput_1[(((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)] = placeholder_25[(((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)] for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 784): Conv2dOutput_1 = T.allocate([1], "int32", "global") for ax3_1 in T.serial(0, 96): - T.store(Conv2dOutput_1, 0, 0, True) + Conv2dOutput_1[0] = 0 for rc_1 in T.serial(0, 192): - T.store(Conv2dOutput_1, 0, (Conv2dOutput_1[0] + (T.cast(PaddedInput_1[((ax0_ax1_fused_ax2_fused_1*192) + rc_1)], "int32")*T.cast(placeholder_26.data[((rc_1*96) + ax3_1)], "int32"))), True) - T.store(T_cast_5.data, ((ax0_ax1_fused_ax2_fused_1*96) + ax3_1), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_1[0] + placeholder_27.data[ax3_1]), 1201322342, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + Conv2dOutput_1[0] = (Conv2dOutput_1[0] + (T.cast(PaddedInput_1[((ax0_ax1_fused_ax2_fused_1*192) + rc_1)], "int32")*T.cast(placeholder_26[((rc_1*96) + ax3_1)], "int32"))) + T_cast_5[((ax0_ax1_fused_ax2_fused_1*96) + ax3_1)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_1[0] + placeholder_27[ax3_1]), 1201322342, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: @@ -431,12 +431,12 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None: @@ -450,15 +450,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2(placehol PaddedInput_2 = T.allocate([150528], "int16", "global") for i0_i1_fused_2 in T.serial(0, 28): for i2_2, i3_2 in T.grid(28, 192): - T.store(PaddedInput_2, (((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2), placeholder_33.data[(((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)], True) + PaddedInput_2[(((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)] = placeholder_33[(((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)] for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 784): Conv2dOutput_2 = T.allocate([64], "int32", "global") for ff_1 in T.serial(0, 64): - T.store(Conv2dOutput_2, ff_1, 0, True) + Conv2dOutput_2[ff_1] = 0 for rc_2 in T.serial(0, 192): - T.store(Conv2dOutput_2, ff_1, (Conv2dOutput_2[ff_1] + (T.cast(PaddedInput_2[((ax0_ax1_fused_ax2_fused_2*192) + rc_2)], "int32")*T.cast(placeholder_34.data[((rc_2*64) + ff_1)], "int32"))), True) + Conv2dOutput_2[ff_1] = (Conv2dOutput_2[ff_1] + (T.cast(PaddedInput_2[((ax0_ax1_fused_ax2_fused_2*192) + rc_2)], "int32")*T.cast(placeholder_34[((rc_2*64) + ff_1)], "int32"))) for ax3_inner_4 in T.serial(0, 64): - T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_2*64) + ax3_inner_4), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_2[ax3_inner_4] + placeholder_35.data[ax3_inner_4]), 1663316467, 31, -7, dtype="int32"), 255), 0), "uint8"), True) + T_cast_9[((ax0_ax1_fused_ax2_fused_2*64) + ax3_inner_4)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_2[ax3_inner_4] + placeholder_35[ax3_inner_4]), 1663316467, 31, -7, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast_1(placeholder_36: T.handle, T_cast_10: T.handle) -> None: @@ -471,12 +471,12 @@ def tvmgen_default_fused_nn_max_pool2d_cast_1(placeholder_36: T.handle, T_cast_1 for ax0_ax1_fused_6 in T.serial(0, 28): for ax2_6 in T.serial(0, 28): for ax3_outer_init_1, ax3_inner_init_1 in T.grid(3, 64): - T.store(tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_init_1*64)) + ax3_inner_init_1), T.uint8(0), True) + tensor_3[((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_init_1*64)) + ax3_inner_init_1)] = T.uint8(0) for rv0_rv1_fused_2, ax3_outer_2, ax3_inner_5 in T.grid(9, 3, 64): - T.store(tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5), T.max(tensor_3[((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)], T.if_then_else(((((1 <= (T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6)) and ((T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6) < 29)) and (1 <= (ax2_6 + T.floormod(rv0_rv1_fused_2, 3)))) and ((ax2_6 + T.floormod(rv0_rv1_fused_2, 3)) < 29)), placeholder_37.data[(((((((T.floordiv(rv0_rv1_fused_2, 3)*5376) + (ax0_ax1_fused_6*5376)) + (ax2_6*192)) + (T.floormod(rv0_rv1_fused_2, 3)*192)) + (ax3_outer_2*64)) + ax3_inner_5) - 5568)], T.uint8(0), dtype="uint8")), True) + tensor_3[((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)] = T.max(tensor_3[((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)], T.if_then_else(((((1 <= (T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6)) and ((T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6) < 29)) and (1 <= (ax2_6 + T.floormod(rv0_rv1_fused_2, 3)))) and ((ax2_6 + T.floormod(rv0_rv1_fused_2, 3)) < 29)), placeholder_37[(((((((T.floordiv(rv0_rv1_fused_2, 3)*5376) + (ax0_ax1_fused_6*5376)) + (ax2_6*192)) + (T.floormod(rv0_rv1_fused_2, 3)*192)) + (ax3_outer_2*64)) + ax3_inner_5) - 5568)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_7 in T.serial(0, 28): for ax2_7, ax3_4 in T.grid(28, 192): - T.store(T_cast_11.data, (((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4), T.cast(tensor_3[(((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)], "int16"), True) + T_cast_11[(((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)] = T.cast(tensor_3[(((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)], "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2(placeholder_38: T.handle, placeholder_39: T.handle, placeholder_40: T.handle, T_cast_12: T.handle) -> None: @@ -490,14 +490,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed PaddedInput_3 = T.allocate([150528], "int16", "global") for i0_i1_fused_3 in T.serial(0, 28): for i2_3, i3_3 in T.grid(28, 192): - T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), placeholder_41.data[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)], True) + PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)] = placeholder_41[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 784): Conv2dOutput_3 = T.allocate([1], "int32", "global") for ax3_5 in T.serial(0, 32): - T.store(Conv2dOutput_3, 0, 0, True) + Conv2dOutput_3[0] = 0 for rc_3 in T.serial(0, 192): - T.store(Conv2dOutput_3, 0, (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_42.data[((rc_3*32) + ax3_5)], "int32"))), True) - T.store(T_cast_13.data, ((ax0_ax1_fused_ax2_fused_3*32) + ax3_5), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_43.data[ax3_5]), 1811141736, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1136333842, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_42[((rc_3*32) + ax3_5)], "int32"))) + T_cast_13[((ax0_ax1_fused_ax2_fused_3*32) + ax3_5)] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_43[ax3_5]), 1811141736, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1136333842, 31, 0, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_44: T.handle, placeholder_45: T.handle, placeholder_46: T.handle, T_cast_14: T.handle) -> None: @@ -511,14 +511,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(pla PaddedInput_4 = T.allocate([150528], "int16", "global") for i0_i1_fused_4 in T.serial(0, 28): for i2_4, i3_4 in T.grid(28, 192): - T.store(PaddedInput_4, (((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4), placeholder_47.data[(((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)], True) + PaddedInput_4[(((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)] = placeholder_47[(((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)] for ax0_ax1_fused_ax2_fused_4 in T.serial(0, 784): Conv2dOutput_4 = T.allocate([1], "int32", "global") for ax3_6 in T.serial(0, 16): - T.store(Conv2dOutput_4, 0, 0, True) + Conv2dOutput_4[0] = 0 for rc_4 in T.serial(0, 192): - T.store(Conv2dOutput_4, 0, (Conv2dOutput_4[0] + (T.cast(PaddedInput_4[((ax0_ax1_fused_ax2_fused_4*192) + rc_4)], "int32")*T.cast(placeholder_48.data[((rc_4*16) + ax3_6)], "int32"))), True) - T.store(T_cast_15.data, ((ax0_ax1_fused_ax2_fused_4*16) + ax3_6), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_4[0] + placeholder_49.data[ax3_6]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + Conv2dOutput_4[0] = (Conv2dOutput_4[0] + (T.cast(PaddedInput_4[((ax0_ax1_fused_ax2_fused_4*192) + rc_4)], "int32")*T.cast(placeholder_48[((rc_4*16) + ax3_6)], "int32"))) + T_cast_15[((ax0_ax1_fused_ax2_fused_4*16) + ax3_6)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_4[0] + placeholder_49[ax3_6]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1(placeholder_50: T.handle, placeholder_51: T.handle, placeholder_52: T.handle, T_cast_16: T.handle) -> None: @@ -532,14 +532,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed PaddedInput_5 = T.allocate([14400], "int16", "global") for i0_i1_fused_5 in T.serial(0, 30): for i2_5, i3_5 in T.grid(30, 16): - T.store(PaddedInput_5, (((i0_i1_fused_5*480) + (i2_5*16)) + i3_5), T.if_then_else(((((1 <= i0_i1_fused_5) and (i0_i1_fused_5 < 29)) and (1 <= i2_5)) and (i2_5 < 29)), placeholder_53.data[((((i0_i1_fused_5*448) + (i2_5*16)) + i3_5) - 464)], T.int16(0), dtype="int16"), True) + PaddedInput_5[(((i0_i1_fused_5*480) + (i2_5*16)) + i3_5)] = T.if_then_else(((((1 <= i0_i1_fused_5) and (i0_i1_fused_5 < 29)) and (1 <= i2_5)) and (i2_5 < 29)), placeholder_53[((((i0_i1_fused_5*448) + (i2_5*16)) + i3_5) - 464)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_5 in T.serial(0, 784): Conv2dOutput_5 = T.allocate([1], "int32", "global") for ax3_7 in T.serial(0, 32): - T.store(Conv2dOutput_5, 0, 0, True) + Conv2dOutput_5[0] = 0 for ry, rx, rc_5 in T.grid(3, 3, 16): - T.store(Conv2dOutput_5, 0, (Conv2dOutput_5[0] + (T.cast(PaddedInput_5[(((((T.floordiv(ax0_ax1_fused_ax2_fused_5, 28)*480) + (ry*480)) + (rx*16)) + (T.floormod(ax0_ax1_fused_ax2_fused_5, 28)*16)) + rc_5)], "int32")*T.cast(placeholder_54.data[((((ry*1536) + (rx*512)) + (rc_5*32)) + ax3_7)], "int32"))), True) - T.store(T_cast_17.data, ((ax0_ax1_fused_ax2_fused_5*32) + ax3_7), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_5[0] + placeholder_55.data[ax3_7]), 1131968888, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1900719667, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + Conv2dOutput_5[0] = (Conv2dOutput_5[0] + (T.cast(PaddedInput_5[(((((T.floordiv(ax0_ax1_fused_ax2_fused_5, 28)*480) + (ry*480)) + (rx*16)) + (T.floormod(ax0_ax1_fused_ax2_fused_5, 28)*16)) + rc_5)], "int32")*T.cast(placeholder_54[((((ry*1536) + (rx*512)) + (rc_5*32)) + ax3_7)], "int32"))) + T_cast_17[((ax0_ax1_fused_ax2_fused_5*32) + ax3_7)] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_5[0] + placeholder_55[ax3_7]), 1131968888, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1900719667, 31, 0, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_(placeholder_56: T.handle, placeholder_57: T.handle, placeholder_58: T.handle, T_cast_18: T.handle) -> None: @@ -553,16 +553,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed PaddedInput_6 = T.allocate([86400], "int16", "global") for i0_i1_fused_6 in T.serial(0, 30): for i2_6, i3_6 in T.grid(30, 96): - T.store(PaddedInput_6, (((i0_i1_fused_6*2880) + (i2_6*96)) + i3_6), T.if_then_else(((((1 <= i0_i1_fused_6) and (i0_i1_fused_6 < 29)) and (1 <= i2_6)) and (i2_6 < 29)), placeholder_59.data[((((i0_i1_fused_6*2688) + (i2_6*96)) + i3_6) - 2784)], T.int16(0), dtype="int16"), True) + PaddedInput_6[(((i0_i1_fused_6*2880) + (i2_6*96)) + i3_6)] = T.if_then_else(((((1 <= i0_i1_fused_6) and (i0_i1_fused_6 < 29)) and (1 <= i2_6)) and (i2_6 < 29)), placeholder_59[((((i0_i1_fused_6*2688) + (i2_6*96)) + i3_6) - 2784)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_6 in T.serial(0, 784): Conv2dOutput_6 = T.allocate([64], "int32", "global") for ax3_outer_3 in T.serial(0, 2): for ff_2 in T.serial(0, 64): - T.store(Conv2dOutput_6, ff_2, 0, True) + Conv2dOutput_6[ff_2] = 0 for ry_1, rx_1, rc_6 in T.grid(3, 3, 96): - T.store(Conv2dOutput_6, ff_2, (Conv2dOutput_6[ff_2] + (T.cast(PaddedInput_6[(((((T.floordiv(ax0_ax1_fused_ax2_fused_6, 28)*2880) + (ry_1*2880)) + (rx_1*96)) + (T.floormod(ax0_ax1_fused_ax2_fused_6, 28)*96)) + rc_6)], "int32")*T.cast(placeholder_60.data[(((((ry_1*36864) + (rx_1*12288)) + (rc_6*128)) + (ax3_outer_3*64)) + ff_2)], "int32"))), True) + Conv2dOutput_6[ff_2] = (Conv2dOutput_6[ff_2] + (T.cast(PaddedInput_6[(((((T.floordiv(ax0_ax1_fused_ax2_fused_6, 28)*2880) + (ry_1*2880)) + (rx_1*96)) + (T.floormod(ax0_ax1_fused_ax2_fused_6, 28)*96)) + rc_6)], "int32")*T.cast(placeholder_60[(((((ry_1*36864) + (rx_1*12288)) + (rc_6*128)) + (ax3_outer_3*64)) + ff_2)], "int32"))) for ax3_inner_6 in T.serial(0, 64): - T.store(T_cast_19.data, (((ax0_ax1_fused_ax2_fused_6*128) + (ax3_outer_3*64)) + ax3_inner_6), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_6[ax3_inner_6] + placeholder_61.data[((ax3_outer_3*64) + ax3_inner_6)]), 1374050734, 31, -7, dtype="int32"), 255), 0), "uint8"), "int32"), 1544713713, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + T_cast_19[(((ax0_ax1_fused_ax2_fused_6*128) + (ax3_outer_3*64)) + ax3_inner_6)] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_6[ax3_inner_6] + placeholder_61[((ax3_outer_3*64) + ax3_inner_6)]), 1374050734, 31, -7, dtype="int32"), 255), 0), "uint8"), "int32"), 1544713713, 31, 0, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: @@ -576,15 +576,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65.data[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16"), True) + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7, ff_3, 0, True) + Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66.data[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))), True) + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67.data[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: @@ -598,16 +598,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): for i2_8, i3_8 in T.grid(58, 64): - T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71.data[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16"), True) + PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): Conv2dOutput_8 = T.allocate([64], "int32", "global") for ax3_outer_4 in T.serial(0, 3): for ff_4 in T.serial(0, 64): - T.store(Conv2dOutput_8, ff_4, 0, True) + Conv2dOutput_8[ff_4] = 0 for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): - T.store(Conv2dOutput_8, ff_4, (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72.data[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))), True) + Conv2dOutput_8[ff_4] = (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))) for ax3_inner_8 in T.serial(0, 64): - T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73.data[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + T_cast_23[(((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -1111,7 +1111,7 @@ def tvmgen_default_fused_layout_transform_1(placeholder: T.handle, T_layout_tran T_layout_trans_1 = T.match_buffer(T_layout_trans, [1, 1, 24, 12, 3], dtype="float32") # body for ax0_ax1_fused_ax2_fused, ax3, ax4_inner in T.grid(24, 12, 3): - T.store(T_layout_trans_1.data, ax0_ax1_fused_ax2_fused * 36 + ax3 * 3 + ax4_inner, placeholder_1.data[ax4_inner * 288 + ax0_ax1_fused_ax2_fused * 12 + ax3], True) + T_layout_trans_1[ax0_ax1_fused_ax2_fused * 36 + ax3 * 3 + ax4_inner] = placeholder_1[ax4_inner * 288 + ax0_ax1_fused_ax2_fused * 12 + ax3] @T.prim_func def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeholder_3: T.handle, conv2d_NCHWc: T.handle) -> None: @@ -1123,60 +1123,60 @@ def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeh # body data_pad = T.allocate([1, 1, 26, 14, 3], "float32", "global") for i0_i1_fused_i2_fused, i3, i4 in T.grid(26, 14, 3): - T.store(data_pad, i0_i1_fused_i2_fused * 42 + i3 * 3 + i4, T.if_then_else(1 <= i0_i1_fused_i2_fused and i0_i1_fused_i2_fused < 25 and 1 <= i3 and i3 < 13, placeholder_4.data[i0_i1_fused_i2_fused * 36 + i3 * 3 + i4 - 39], T.float32(0), dtype="float32"), True) + data_pad[i0_i1_fused_i2_fused * 42 + i3 * 3 + i4] = T.if_then_else(1 <= i0_i1_fused_i2_fused and i0_i1_fused_i2_fused < 25 and 1 <= i3 and i3 < 13, placeholder_4[i0_i1_fused_i2_fused * 36 + i3 * 3 + i4 - 39], T.float32(0), dtype="float32") for n_oc_chunk_fused_oh_fused in T.serial(0, 24): conv2d_NCHWc_global = T.allocate([1, 1, 1, 12, 3], "float32", "global") for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 3, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 3] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 6, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 6] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 9, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 9] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 12, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 12] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 15, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 15] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 18, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 18] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 21, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 21] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 24, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 24] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 27, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 27] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 30, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 30] = T.float32(0) for oc_block_c_init in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c_init + 33, T.float32(0), True) + conv2d_NCHWc_global[oc_block_c_init + 33] = T.float32(0) for kh, kw, ic_inner in T.grid(3, 3, 3): for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c, conv2d_NCHWc_global[oc_block_c] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c] = conv2d_NCHWc_global[oc_block_c] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 3, conv2d_NCHWc_global[oc_block_c + 3] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 3] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c + 3] = conv2d_NCHWc_global[oc_block_c + 3] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 3] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 6, conv2d_NCHWc_global[oc_block_c + 6] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 6] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c + 6] = conv2d_NCHWc_global[oc_block_c + 6] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 6] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 9, conv2d_NCHWc_global[oc_block_c + 9] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 9] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c + 9] = conv2d_NCHWc_global[oc_block_c + 9] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 9] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 12, conv2d_NCHWc_global[oc_block_c + 12] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 12] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c + 12] = conv2d_NCHWc_global[oc_block_c + 12] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 12] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 15, conv2d_NCHWc_global[oc_block_c + 15] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 15] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c + 15] = conv2d_NCHWc_global[oc_block_c + 15] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 15] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 18, conv2d_NCHWc_global[oc_block_c + 18] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 18] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c + 18] = conv2d_NCHWc_global[oc_block_c + 18] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 18] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 21, conv2d_NCHWc_global[oc_block_c + 21] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 21] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c + 21] = conv2d_NCHWc_global[oc_block_c + 21] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 21] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 24, conv2d_NCHWc_global[oc_block_c + 24] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 24] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c + 24] = conv2d_NCHWc_global[oc_block_c + 24] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 24] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 27, conv2d_NCHWc_global[oc_block_c + 27] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 27] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c + 27] = conv2d_NCHWc_global[oc_block_c + 27] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 27] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 30, conv2d_NCHWc_global[oc_block_c + 30] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 30] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c + 30] = conv2d_NCHWc_global[oc_block_c + 30] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 30] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for oc_block_c in T.serial(0, 3): - T.store(conv2d_NCHWc_global, oc_block_c + 33, conv2d_NCHWc_global[oc_block_c + 33] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 33] * placeholder_5.data[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c], True) + conv2d_NCHWc_global[oc_block_c + 33] = conv2d_NCHWc_global[oc_block_c + 33] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 33] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] for ow_inner, oc_block in T.grid(12, 3): - T.store(conv2d_NCHWc_1.data, n_oc_chunk_fused_oh_fused * 36 + ow_inner * 3 + oc_block, conv2d_NCHWc_global[ow_inner * 3 + oc_block], True) + conv2d_NCHWc_1[n_oc_chunk_fused_oh_fused * 36 + ow_inner * 3 + oc_block] = conv2d_NCHWc_global[ow_inner * 3 + oc_block] @T.prim_func def tvmgen_default_fused_nn_softmax_add_add_multiply_add(placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, T_add: T.handle) -> None: @@ -1192,20 +1192,20 @@ def tvmgen_default_fused_nn_softmax_add_add_multiply_add(placeholder_6: T.handle for ax0_ax1_fused_ax2_fused in T.serial(0, 72): T_softmax_norm = T.allocate([1, 1, 1, 12], "float32", "global") with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: - T.store(T_softmax_maxelem, 0, T.float32(-3.4028234663852886e+38), True) + T_softmax_maxelem[0] = T.float32(-3.4028234663852886e+38) for k in T.serial(0, 12): - T.store(T_softmax_maxelem, 0, T.max(T_softmax_maxelem[0], placeholder_11.data[ax0_ax1_fused_ax2_fused * 12 + k]), True) + T_softmax_maxelem[0] = T.max(T_softmax_maxelem[0], placeholder_11[ax0_ax1_fused_ax2_fused * 12 + k]) T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") for i3 in T.serial(0, 12): - T.store(T_softmax_exp, i3, T.exp(placeholder_11.data[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32"), True) + T_softmax_exp[i3] = T.exp(placeholder_11[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32") T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") - T.store(T_softmax_expsum, 0, T.float32(0), True) + T_softmax_expsum[0] = T.float32(0) for k in T.serial(0, 12): - T.store(T_softmax_expsum, 0, T_softmax_expsum[0] + T_softmax_exp[k], True) + T_softmax_expsum[0] = T_softmax_expsum[0] + T_softmax_exp[k] for i3 in T.serial(0, 12): - T.store(T_softmax_norm, i3, T_softmax_exp[i3] / T_softmax_expsum[0], True) + T_softmax_norm[i3] = T_softmax_exp[i3] / T_softmax_expsum[0] for ax3 in T.serial(0, 12): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused * 12 + ax3, (placeholder_12.data[ax0_ax1_fused_ax2_fused * 12 + ax3] + T_softmax_norm[ax3] + placeholder_13.data[T.floordiv(ax0_ax1_fused_ax2_fused, 24)]) * placeholder_14.data[T.floordiv(ax0_ax1_fused_ax2_fused, 24)] + placeholder_15.data[T.floordiv(ax0_ax1_fused_ax2_fused, 24)], True) + T_add_1[ax0_ax1_fused_ax2_fused * 12 + ax3] = (placeholder_12[ax0_ax1_fused_ax2_fused * 12 + ax3] + T_softmax_norm[ax3] + placeholder_13[T.floordiv(ax0_ax1_fused_ax2_fused, 24)]) * placeholder_14[T.floordiv(ax0_ax1_fused_ax2_fused, 24)] + placeholder_15[T.floordiv(ax0_ax1_fused_ax2_fused, 24)] @T.prim_func def tvmgen_default_fused_nn_contrib_dense_pack_nn_relu(placeholder_16: T.handle, placeholder_17: T.handle, T_relu: T.handle) -> None: @@ -1219,56 +1219,56 @@ def tvmgen_default_fused_nn_contrib_dense_pack_nn_relu(placeholder_16: T.handle, compute = T.allocate([8, 6], "float32", "global") with T.allocate([8, 6], "float32", "global") as compute_global: for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init, T.float32(0), True) + compute_global[x_c_init] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 6, T.float32(0), True) + compute_global[x_c_init + 6] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 12, T.float32(0), True) + compute_global[x_c_init + 12] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 18, T.float32(0), True) + compute_global[x_c_init + 18] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 24, T.float32(0), True) + compute_global[x_c_init + 24] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 30, T.float32(0), True) + compute_global[x_c_init + 30] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 36, T.float32(0), True) + compute_global[x_c_init + 36] = T.float32(0) for x_c_init in T.serial(0, 6): - T.store(compute_global, x_c_init + 42, T.float32(0), True) + compute_global[x_c_init + 42] = T.float32(0) for k_outer in T.serial(0, 12): for x_c in T.serial(0, 6): - T.store(compute_global, x_c, compute_global[x_c] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) + compute_global[x_c] = compute_global[x_c] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 6, compute_global[x_c + 6] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 12] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) + compute_global[x_c + 6] = compute_global[x_c + 6] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 12] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 12, compute_global[x_c + 12] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 24] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) + compute_global[x_c + 12] = compute_global[x_c + 12] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 24] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 18, compute_global[x_c + 18] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 36] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) + compute_global[x_c + 18] = compute_global[x_c + 18] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 36] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 24, compute_global[x_c + 24] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 48] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) + compute_global[x_c + 24] = compute_global[x_c + 24] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 48] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 30, compute_global[x_c + 30] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 60] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) + compute_global[x_c + 30] = compute_global[x_c + 30] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 60] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 36, compute_global[x_c + 36] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 72] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) + compute_global[x_c + 36] = compute_global[x_c + 36] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 72] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_c in T.serial(0, 6): - T.store(compute_global, x_c + 42, compute_global[x_c + 42] + placeholder_18.data[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 84] * placeholder_19.data[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c], True) + compute_global[x_c + 42] = compute_global[x_c + 42] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 84] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner, compute_global[x_inner_inner], True) + compute[x_inner_inner] = compute_global[x_inner_inner] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 6, compute_global[x_inner_inner + 6], True) + compute[x_inner_inner + 6] = compute_global[x_inner_inner + 6] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 12, compute_global[x_inner_inner + 12], True) + compute[x_inner_inner + 12] = compute_global[x_inner_inner + 12] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 18, compute_global[x_inner_inner + 18], True) + compute[x_inner_inner + 18] = compute_global[x_inner_inner + 18] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 24, compute_global[x_inner_inner + 24], True) + compute[x_inner_inner + 24] = compute_global[x_inner_inner + 24] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 30, compute_global[x_inner_inner + 30], True) + compute[x_inner_inner + 30] = compute_global[x_inner_inner + 30] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 36, compute_global[x_inner_inner + 36], True) + compute[x_inner_inner + 36] = compute_global[x_inner_inner + 36] for x_inner_inner in T.serial(0, 6): - T.store(compute, x_inner_inner + 42, compute_global[x_inner_inner + 42], True) + compute[x_inner_inner + 42] = compute_global[x_inner_inner + 42] for ax0_inner_inner, ax1_inner_inner in T.grid(8, 6): - T.store(T_relu_1.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + ax0_inner_inner * 12 + T.floordiv(ax1_outer_ax0_outer_fused, 9) * 6 + ax1_inner_inner, T.max(compute[ax0_inner_inner * 6 + ax1_inner_inner], T.float32(0)), True) + T_relu_1[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + ax0_inner_inner * 12 + T.floordiv(ax1_outer_ax0_outer_fused, 9) * 6 + ax1_inner_inner] = T.max(compute[ax0_inner_inner * 6 + ax1_inner_inner], T.float32(0)) @T.prim_func def tvmgen_default_fused_reshape_1(placeholder_20: T.handle, T_reshape: T.handle) -> None: @@ -1278,7 +1278,7 @@ def tvmgen_default_fused_reshape_1(placeholder_20: T.handle, T_reshape: T.handle T_reshape_1 = T.match_buffer(T_reshape, [72, 12], dtype="float32") # body for ax0, ax1_inner in T.grid(72, 12): - T.store(T_reshape_1.data, ax0 * 12 + ax1_inner, placeholder_21.data[ax0 * 12 + ax1_inner], True) + T_reshape_1[ax0 * 12 + ax1_inner] = placeholder_21[ax0 * 12 + ax1_inner] @T.prim_func def tvmgen_default_fused_layout_transform(placeholder_22: T.handle, T_layout_trans_2: T.handle) -> None: @@ -1288,7 +1288,7 @@ def tvmgen_default_fused_layout_transform(placeholder_22: T.handle, T_layout_tra T_layout_trans_3 = T.match_buffer(T_layout_trans_2, [1, 3, 24, 12], dtype="float32") # body for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): - T.store(T_layout_trans_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, placeholder_23.data[ax2 * 36 + ax3_inner * 3 + ax0_ax1_fused], True) + T_layout_trans_3[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] = placeholder_23[ax2 * 36 + ax3_inner * 3 + ax0_ax1_fused] @T.prim_func def tvmgen_default_fused_reshape(placeholder_24: T.handle, T_reshape_2: T.handle) -> None: @@ -1298,7 +1298,7 @@ def tvmgen_default_fused_reshape(placeholder_24: T.handle, T_reshape_2: T.handle T_reshape_3 = T.match_buffer(T_reshape_2, [1, 3, 24, 12], dtype="float32") # body for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): - T.store(T_reshape_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, placeholder_25.data[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner], True) + T_reshape_3[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] = placeholder_25[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] @T.prim_func def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27: T.handle, T_add_2: T.handle) -> None: @@ -1311,20 +1311,20 @@ def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27 for ax0_ax1_fused_ax2_fused in T.serial(0, 72): T_softmax_norm = T.allocate([1, 1, 1, 12], "float32", "global") with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: - T.store(T_softmax_maxelem, 0, T.float32(-3.4028234663852886e+38), True) + T_softmax_maxelem[0] = T.float32(-3.4028234663852886e+38) for k in T.serial(0, 12): - T.store(T_softmax_maxelem, 0, T.max(T_softmax_maxelem[0], placeholder_28.data[ax0_ax1_fused_ax2_fused * 12 + k]), True) + T_softmax_maxelem[0] = T.max(T_softmax_maxelem[0], placeholder_28[ax0_ax1_fused_ax2_fused * 12 + k]) T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") for i3 in T.serial(0, 12): - T.store(T_softmax_exp, i3, T.exp(placeholder_28.data[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32"), True) + T_softmax_exp[i3] = T.exp(placeholder_28[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32") T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") - T.store(T_softmax_expsum, 0, T.float32(0), True) + T_softmax_expsum[0] = T.float32(0) for k in T.serial(0, 12): - T.store(T_softmax_expsum, 0, T_softmax_expsum[0] + T_softmax_exp[k], True) + T_softmax_expsum[0] = T_softmax_expsum[0] + T_softmax_exp[k] for i3 in T.serial(0, 12): - T.store(T_softmax_norm, i3, T_softmax_exp[i3] / T_softmax_expsum[0], True) + T_softmax_norm[i3] = T_softmax_exp[i3] / T_softmax_expsum[0] for ax3 in T.serial(0, 12): - T.store(T_add_3.data, ax0_ax1_fused_ax2_fused * 12 + ax3, placeholder_29.data[ax0_ax1_fused_ax2_fused * 12 + ax3] + T_softmax_norm[ax3], True) + T_add_3[ax0_ax1_fused_ax2_fused * 12 + ax3] = placeholder_29[ax0_ax1_fused_ax2_fused * 12 + ax3] + T_softmax_norm[ax3] @T.prim_func def run_model(data: T.handle, output: T.handle) -> None: diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 9adfb7639ada..1d42dade372e 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -79,7 +79,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5.data[0]), True) + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: @@ -93,15 +93,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65.data[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16"), True) + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7, ff_3, 0, True) + Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66.data[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))), True) + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67.data[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: @@ -114,12 +114,12 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -130,9 +130,9 @@ def run_model(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) # fmt: on @@ -146,8 +146,8 @@ def run_model(input: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_9_let: T.handle = T.address_of(slow_memory_1_buffer_var.data[1117472], dtype="handle") - sid_8_let: T.handle = T.address_of(slow_memory_1_buffer_var.data[0], dtype="handle") + sid_9_let: T.handle = T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle") + sid_8_let: T.handle = T.address_of(slow_memory_1_buffer_var[0], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) @@ -159,14 +159,15 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - tensor_2_let: T.handle = T.address_of(fast_memory_6_buffer_var.data[0], dtype="handle") - for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56): - for ax3_init in T.serial(0, 64): - T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init, T.uint8(0), True) - for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2, T.max(tensor_2_let[ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2], T.if_then_else(ax0_ax1_fused_4 * 2 + rv0_rv1_fused_1 // 3 < 112 and ax2_4 * 2 + rv0_rv1_fused_1 % 3 < 112, placeholder_29.data[ax0_ax1_fused_4 * 14336 + rv0_rv1_fused_1 // 3 * 7168 + ax2_4 * 128 + rv0_rv1_fused_1 % 3 * 64 + ax3_2], T.uint8(0), dtype="uint8")), True) - for ax0_ax1_fused_5, ax2_5, ax3_3 in T.grid(56, 56, 64): - T.store(T_cast_7.data, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3, T.cast(tensor_2_let[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3], "int16"), True) + tensor_2_let = T.buffer_decl([200704], dtype="uint8") + with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")): + for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56): + for ax3_init in T.serial(0, 64): + tensor_2_let[ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init] = T.uint8(0) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + tensor_2_let[ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2] = T.max(tensor_2_let[ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2], T.if_then_else(ax0_ax1_fused_4 * 2 + rv0_rv1_fused_1 // 3 < 112 and ax2_4 * 2 + rv0_rv1_fused_1 % 3 < 112, placeholder_29[ax0_ax1_fused_4 * 14336 + rv0_rv1_fused_1 // 3 * 7168 + ax2_4 * 128 + rv0_rv1_fused_1 % 3 * 64 + ax3_2], T.uint8(0), dtype="uint8")) + for ax0_ax1_fused_5, ax2_5, ax3_3 in T.grid(56, 56, 64): + T_cast_7[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3] = T.cast(tensor_2_let[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3], "int16") @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle, slow_memory_3_var: T.handle) -> None: @@ -177,7 +178,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): - T.store(T_subtract_1.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1, T.cast(placeholder_4.data[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5.data[0], True) + T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle, slow_memory_5_var: T.handle) -> None: @@ -188,17 +189,19 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_7_let: T.handle = T.address_of(slow_memory_5_buffer_var.data[802816], dtype="handle") - for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3): - T.store(PaddedInput_7_let, i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7, T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, placeholder_65.data[i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350], T.int16(0), dtype="int16"), True) - for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): - Conv2dOutput_7_let: T.handle = T.address_of(fast_memory_4_buffer_var.data[0], dtype="handle") - for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7_let, ff_3, 0, True) - for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7_let, ff_3, Conv2dOutput_7_let[ff_3] + T.cast(PaddedInput_7_let[ax0_ax1_fused_ax2_fused_7 // 112 * 1374 + ry_2 * 687 + ax0_ax1_fused_ax2_fused_7 % 112 * 6 + rx_2 * 3 + rc_7], "int32") * T.cast(placeholder_66.data[ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3], "int32"), True) - for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7, T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_7_let[ax3_inner_7] + placeholder_67.data[ax3_inner_7], 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + PaddedInput_7_let = T.buffer_decl([307856], "int16") + with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): + for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3): + PaddedInput_7_let[i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7] = T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, placeholder_65[i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7_let = T.buffer_decl([50176], "int32") + with T.let(Conv2dOutput_7_let.data, T.address_of(fast_memory_4_buffer_var[0], dtype="handle")): + for ff_3 in T.serial(0, 64): + Conv2dOutput_7_let[ff_3] = 0 + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + Conv2dOutput_7_let[ff_3] = Conv2dOutput_7_let[ff_3] + T.cast(PaddedInput_7_let[ax0_ax1_fused_ax2_fused_7 // 112 * 1374 + ry_2 * 687 + ax0_ax1_fused_ax2_fused_7 % 112 * 6 + rx_2 * 3 + rc_7], "int32") * T.cast(placeholder_66[ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3], "int32") + for ax3_inner_7 in T.serial(0, 64): + T_cast_21[ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7] = T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_7_let[ax3_inner_7] + placeholder_67[ax3_inner_7], 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") # fmt: on @@ -259,7 +262,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): - T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2.data[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3.data[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16"), True) + T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: @@ -272,15 +275,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): - T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13.data[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16"), True) + PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): Conv2dOutput_1 = T.allocate([64], "int32", "global") for ff_1 in T.serial(0, 64): - T.store(Conv2dOutput_1, ff_1, 0, True) + Conv2dOutput_1[ff_1] = 0 for ry, rx, rc_1 in T.grid(3, 3, 64): - T.store(Conv2dOutput_1, ff_1, Conv2dOutput_1[ff_1] + T.cast(PaddedInput_1[T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1], "int32") * T.cast(placeholder_14.data[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32"), True) + Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(PaddedInput_1[T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1], "int32") * T.cast(placeholder_14[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32") for ax3_inner_2 in T.serial(0, 64): - T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1[ax3_inner_2] + placeholder_15.data[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1[ax3_inner_2] + placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: @@ -293,16 +296,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): - T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, placeholder_19.data[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2], True) + PaddedInput_2[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): Conv2dOutput_2 = T.allocate([64], "int32", "global") for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): - T.store(Conv2dOutput_2, ff_2, 0, True) + Conv2dOutput_2[ff_2] = 0 for rc_2 in T.serial(0, 64): - T.store(Conv2dOutput_2, ff_2, Conv2dOutput_2[ff_2] + T.cast(PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20.data[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32"), True) + Conv2dOutput_2[ff_2] = Conv2dOutput_2[ff_2] + T.cast(PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32") for ax3_inner_3 in T.serial(0, 64): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2[ax3_inner_3] + placeholder_21.data[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136 @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: @@ -316,16 +319,16 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): - T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, placeholder_29.data[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3], True) + PaddedInput_3[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): Conv2dOutput_3 = T.allocate([64], "int32", "global") for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_3, ff_3, 0, True) + Conv2dOutput_3[ff_3] = 0 for rc_3 in T.serial(0, 64): - T.store(Conv2dOutput_3, ff_3, Conv2dOutput_3[ff_3] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27.data[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32"), True) + Conv2dOutput_3[ff_3] = Conv2dOutput_3[ff_3] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32") for ax3_inner_4 in T.serial(0, 64): - T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3[ax3_inner_4] + placeholder_26.data[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28.data[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8"), True) + T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8") @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: @@ -338,11 +341,11 @@ def run_model(input: T.handle, output: T.handle) -> None: sid_6 = T.allocate([5760000], "int8", "global") sid_7 = T.allocate([720000], "int8", "global") sid_8 = T.allocate([720000], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2.data, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8.data, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7.data, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6.data, output, dtype="int32")) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: @@ -355,15 +358,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): - T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, placeholder_7.data[i0_i1_fused * 4800 + i2 * 64 + i3], True) + PaddedInput[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3] for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): Conv2dOutput = T.allocate([64], "int32", "global") for ff in T.serial(0, 64): - T.store(Conv2dOutput, ff, 0, True) + Conv2dOutput[ff] = 0 for rc in T.serial(0, 64): - T.store(Conv2dOutput, ff, Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8.data[rc * 64 + ff], "int32"), True) + Conv2dOutput[ff] = Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32") for ax3_inner_1 in T.serial(0, 64): - T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9.data[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") # fmt: on @@ -378,7 +381,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): - T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2.data[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3.data[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16"), True) + T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None: @@ -389,18 +392,20 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_3_let: T.handle = T.address_of(global_workspace_5_buffer_var.data[6480000], dtype="handle") - for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): - T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, placeholder_29.data[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3], True) - for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): - Conv2dOutput_3_let: T.handle = T.address_of(global_workspace_5_buffer_var.data[7200000], dtype="handle") - for ax3_outer_2 in T.serial(0, 4): - for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_3_let, ff_3, 0, True) - for rc_3 in T.serial(0, 64): - T.store(Conv2dOutput_3_let, ff_3, Conv2dOutput_3_let[ff_3] + T.cast(PaddedInput_3_let[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27.data[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32"), True) - for ax3_inner_4 in T.serial(0, 64): - T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3_let[ax3_inner_4] + placeholder_26.data[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28.data[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8"), True) + PaddedInput_3_let = T.buffer_decl([360000], 'int16') + with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")): + for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): + PaddedInput_3_let[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] + for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): + Conv2dOutput_3_let = T.buffer_decl([180064], 'int32') + with T.let(Conv2dOutput_3_let.data, T.address_of(global_workspace_5_buffer_var[7200000], dtype="handle")): + for ax3_outer_2 in T.serial(0, 4): + for ff_3 in T.serial(0, 64): + Conv2dOutput_3_let[ff_3] = 0 + for rc_3 in T.serial(0, 64): + Conv2dOutput_3_let[ff_3] = Conv2dOutput_3_let[ff_3] + T.cast(PaddedInput_3_let[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32") + for ax3_inner_4 in T.serial(0, 64): + T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3_let[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle) -> None: @@ -410,18 +415,20 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_2_let: T.handle = T.address_of(global_workspace_4_buffer_var.data[7200000], dtype="handle") - for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): - T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, placeholder_19.data[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2], True) - for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): - Conv2dOutput_2_let: T.handle = T.address_of(global_workspace_4_buffer_var.data[7920000], dtype="handle") - for ax3_outer_1 in T.serial(0, 4): - for ff_2 in T.serial(0, 64): - T.store(Conv2dOutput_2_let, ff_2, 0, True) - for rc_2 in T.serial(0, 64): - T.store(Conv2dOutput_2_let, ff_2, Conv2dOutput_2_let[ff_2] + T.cast(PaddedInput_2_let[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20.data[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32"), True) - for ax3_inner_3 in T.serial(0, 64): - T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2_let[ax3_inner_3] + placeholder_21.data[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + PaddedInput_2_let = T.buffer_decl([360000], "int16") + with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")): + for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): + PaddedInput_2_let[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): + Conv2dOutput_2_let = T.buffer_decl([64], 'int32') + with T.let(Conv2dOutput_2_let.data, T.address_of(global_workspace_4_buffer_var[7920000], dtype="handle")): + for ax3_outer_1 in T.serial(0, 4): + for ff_2 in T.serial(0, 64): + Conv2dOutput_2_let[ff_2] = 0 + for rc_2 in T.serial(0, 64): + Conv2dOutput_2_let[ff_2] = Conv2dOutput_2_let[ff_2] + T.cast(PaddedInput_2_let[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32") + for ax3_inner_3 in T.serial(0, 64): + T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2_let[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136 @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None: @@ -431,17 +438,19 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_let: T.handle = T.address_of(global_workspace_2_buffer_var.data[7200000], dtype="handle") - for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): - T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, placeholder_7.data[i0_i1_fused * 4800 + i2 * 64 + i3], True) - for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): - Conv2dOutput_let: T.handle = T.address_of(global_workspace_2_buffer_var.data[7920000], dtype="handle") - for ff in T.serial(0, 64): - T.store(Conv2dOutput_let, ff, 0, True) - for rc in T.serial(0, 64): - T.store(Conv2dOutput_let, ff, Conv2dOutput_let[ff] + T.cast(PaddedInput_let[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8.data[rc * 64 + ff], "int32"), True) - for ax3_inner_1 in T.serial(0, 64): - T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_let[ax3_inner_1] + placeholder_9.data[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + PaddedInput_let = T.buffer_decl([360000], "int16") + with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")): + for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): + PaddedInput_let[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3] + for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): + Conv2dOutput_let = T.buffer_decl([64], "int32") + with T.let(Conv2dOutput_let.data, T.address_of(global_workspace_2_buffer_var[7920000], dtype="handle")): + for ff in T.serial(0, 64): + Conv2dOutput_let[ff] = 0 + for rc in T.serial(0, 64): + Conv2dOutput_let[ff] = Conv2dOutput_let[ff] + T.cast(PaddedInput_let[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32") + for ax3_inner_1 in T.serial(0, 64): + T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_let[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None: @@ -451,17 +460,19 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_1_let: T.handle = T.address_of(global_workspace_3_buffer_var.data[0], dtype="handle") - for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): - T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13.data[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16"), True) - for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): - Conv2dOutput_1_let: T.handle = T.address_of(global_workspace_3_buffer_var.data[7200000], dtype="handle") - for ff_1 in T.serial(0, 64): - T.store(Conv2dOutput_1_let, ff_1, 0, True) - for ry, rx, rc_1 in T.grid(3, 3, 64): - T.store(Conv2dOutput_1_let, ff_1, Conv2dOutput_1_let[ff_1] + T.cast(PaddedInput_1_let[ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1], "int32") * T.cast(placeholder_14.data[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32"), True) - for ax3_inner_2 in T.serial(0, 64): - T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1_let[ax3_inner_2] + placeholder_15.data[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + PaddedInput_1_let = T.buffer_decl([3600000], "int16") + with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): + for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): + PaddedInput_1_let[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): + Conv2dOutput_1_let = T.buffer_decl([180064], "int32") + with T.let(Conv2dOutput_1_let.data, T.address_of(global_workspace_3_buffer_var[7200000], dtype="handle")): + for ff_1 in T.serial(0, 64): + Conv2dOutput_1_let[ff_1] = 0 + for ry, rx, rc_1 in T.grid(3, 3, 64): + Conv2dOutput_1_let[ff_1] = Conv2dOutput_1_let[ff_1] + T.cast(PaddedInput_1_let[ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1], "int32") * T.cast(placeholder_14[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32") + for ax3_inner_2 in T.serial(0, 64): + T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1_let[ax3_inner_2] + placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func def run_model(input: T.handle, global_workspace_0_var: T.handle, output: T.handle) -> None: @@ -469,10 +480,10 @@ def run_model(input: T.handle, global_workspace_0_var: T.handle, output: T.handl # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_2_let: T.handle = T.address_of(global_workspace_0_buffer_var.data[5760000], dtype="handle") - sid_6_let: T.handle = T.address_of(global_workspace_0_buffer_var.data[0], dtype="handle") - sid_7_let: T.handle = T.address_of(global_workspace_0_buffer_var.data[6480000], dtype="handle") - sid_8_let: T.handle = T.address_of(global_workspace_0_buffer_var.data[6480000], dtype="handle") + sid_2_let: T.handle = T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle") + sid_6_let: T.handle = T.address_of(global_workspace_0_buffer_var[0], dtype="handle") + sid_7_let: T.handle = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") + sid_8_let: T.handle = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32")) diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py index 3d90687118b3..fa70cec9de4f 100644 --- a/tests/python/unittest/test_tir_usmp_utils.py +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -37,7 +37,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): - T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(placeholder_4.data[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5.data[0]), True) + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: @@ -51,15 +51,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): for i2_7, i3_7 in T.grid(229, 3): - T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65.data[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16"), True) + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7 = T.allocate([64], "int32", "global") for ff_3 in T.serial(0, 64): - T.store(Conv2dOutput_7, ff_3, 0, True) + Conv2dOutput_7[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): - T.store(Conv2dOutput_7, ff_3, (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66.data[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))), True) + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) for ax3_inner_7 in T.serial(0, 64): - T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67.data[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: @@ -72,12 +72,12 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") @T.prim_func def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: @@ -88,9 +88,9 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 35a63a58e89c..df397fb81c73 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -93,134 +93,66 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) # body - packedB = T.allocate([32768], "float32x32", "global") + packedB = T.allocate([32768], "float32", "global") for x in T.parallel(0, 32): for y in T.serial(0, 1024): - T.store( - packedB, - T.ramp(((x * 32768) + (y * 32)), 1, 32), - B_1.data[ - T.ramp(((y * 1024) + (x * 32)), 1, 32), - T.broadcast(True, 32), - ], - T.broadcast(True, 32), - ) + packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B_1[ + T.ramp(((y * 1024) + (x * 32)), 1, 32) + ] for x_outer in T.parallel(0, 32): C_global = T.allocate([1024], "float32", "global") for y_outer in T.serial(0, 32): for x_c_init in T.serial(0, 32): - T.store( - C_global, - T.ramp((x_c_init * 32), 1, 32), - T.broadcast(T.float32(0), 32), - T.broadcast(True, 32), - ) + C_global[T.ramp((x_c_init * 32), 1, 32)] = T.broadcast(T.float32(0), 32) for k_outer in T.serial(0, 256): for x_c in T.serial(0, 32): - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - ( - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ] - + ( - T.broadcast( - A_1.data[ - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)), - ], - 32, - ) - * packedB[ - T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), - T.broadcast(True, 32), - ] - ) - ), - T.broadcast(True, 32), + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)), + ], + 32, + ) + * packedB[T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32)] ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - ( - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ] - + ( - T.broadcast( - A_1.data[ - ( - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) - + 1 - ), - ], - 32, - ) - * packedB[ - T.ramp((((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32), - T.broadcast(True, 32), - ] - ) - ), - T.broadcast(True, 32), + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 1), + ], + 32, + ) + * packedB[T.ramp((((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32)] ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - ( - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ] - + ( - T.broadcast( - A_1.data[ - ( - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) - + 2 - ), - ], - 32, - ) - * packedB[ - T.ramp((((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32), - T.broadcast(True, 32), - ] - ) - ), - T.broadcast(True, 32), + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 2), + ], + 32, + ) + * packedB[T.ramp((((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32)] ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - ( - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ] - + ( - T.broadcast( - A_1.data[ - ( - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) - + 3 - ), - ], - 32, - ) - * packedB[ - T.ramp((((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32), - T.broadcast(True, 32), - ] - ) - ), - T.broadcast(True, 32), + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 3), + ], + 32, + ) + * packedB[T.ramp((((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32)] ) for x_inner in T.serial(0, 32): for y_inner in T.serial(0, 32): - C_1.data[ + C_1[ ((((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) + y_inner) ] = C_global[((x_inner * 32) + y_inner)] @@ -375,15 +307,10 @@ def mmult( T.evaluate(T.tvm_throw_last_error(dtype="int32")) for x in T.parallel(0, 32): for y in T.serial(0, 1024): - T.store( - packedB, - T.ramp(((x * 32768) + (y * 32)), 1, 32), - B[ - T.ramp(((y * 1024) + (x * 32)), 1, 32), - T.broadcast(True, 32), - ], + packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B[ + T.ramp(((y * 1024) + (x * 32)), 1, 32), T.broadcast(True, 32), - ) + ] for x_outer in T.parallel(0, 32): T.attr(C_global, "storage_scope", "global") T.attr(C_global, "storage_alignment", 128) @@ -395,136 +322,93 @@ def mmult( T.evaluate(T.tvm_throw_last_error(dtype="int32")) for y_outer in T.serial(0, 32): for x_c_init in T.serial(0, 32): - T.store( - C_global, - T.ramp((x_c_init * 32), 1, 32), - T.broadcast(T.float32(0), 32), - T.broadcast(True, 32), - ) + C_global[T.ramp((x_c_init * 32), 1, 32)] = T.broadcast(T.float32(0), 32) for k_outer in T.serial(0, 256): for x_c in T.serial(0, 32): - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - T.call_llvm_pure_intrin( - T.uint32(97), - T.uint32(3), - T.broadcast( - A[ - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ), - ], - 32, - ), - packedB[ - T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), - T.broadcast(True, 32), - ], - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + A[ + (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)), ], - dtype="float32x32", + 32, ), - T.broadcast(True, 32), + packedB[ + T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), + T.broadcast(True, 32), + ], + C_global[ + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), + ], + dtype="float32x32", ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - T.call_llvm_pure_intrin( - T.uint32(97), - T.uint32(3), - T.broadcast( - A[ - ( - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ) - + 1 - ), - ], - 32, - ), - packedB[ - T.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32 + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + A[ + ( + (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + + 1 ), - T.broadcast(True, 32), ], - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ], - dtype="float32x32", + 32, ), - T.broadcast(True, 32), + packedB[ + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32), + T.broadcast(True, 32), + ], + C_global[ + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), + ], + dtype="float32x32", ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - T.call_llvm_pure_intrin( - T.uint32(97), - T.uint32(3), - T.broadcast( - A[ - ( - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ) - + 2 - ), - ], - 32, - ), - packedB[ - T.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32 + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + A[ + ( + (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + + 2 ), - T.broadcast(True, 32), ], - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ], - dtype="float32x32", + 32, ), - T.broadcast(True, 32), + packedB[ + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32), + T.broadcast(True, 32), + ], + C_global[ + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), + ], + dtype="float32x32", ) - T.store( - C_global, - T.ramp((x_c * 32), 1, 32), - T.call_llvm_pure_intrin( - T.uint32(97), - T.uint32(3), - T.broadcast( - A[ - ( - ( - ((x_outer * 32768) + (x_c * 1024)) - + (k_outer * 4) - ) - + 3 - ), - ], - 32, - ), - packedB[ - T.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32 + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + A[ + ( + (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + + 3 ), - T.broadcast(True, 32), - ], - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), ], - dtype="float32x32", + 32, ), - T.broadcast(True, 32), + packedB[ + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32), + T.broadcast(True, 32), + ], + C_global[ + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), + ], + dtype="float32x32", ) for x_inner in T.serial(0, 32): for y_inner in T.serial(0, 32): @@ -1078,7 +962,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1119,7 +1003,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1160,7 +1044,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1201,7 +1085,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1242,7 +1126,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1283,7 +1167,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1324,7 +1208,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1365,7 +1249,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1406,7 +1290,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1447,7 +1331,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1488,7 +1372,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1529,7 +1413,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1570,7 +1454,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1611,7 +1495,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1652,7 +1536,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1690,7 +1574,7 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) and ((ax2 + T.floormod(bz, 14)) < 15) ), - A_1.data[ + A_1[ ( ( ( @@ -1715,158 +1599,125 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: dtype="float16", ) with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8), - W_1.data[ - T.ramp( + W_shared[T.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8)] = W_1[ + T.ramp( + ( ( - ( - ( - (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) - + (ty * 512) - ) - + (tz * 256) - ) - + (tx * 8) - ), - 1, - 8, + ((((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + (ty * 512)) + + (tz * 256) + ) + + (tx * 8) ), - T.broadcast(True, 8), - ], + 1, + 8, + ), T.broadcast(True, 8), - ) + ] with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 2048), 1, 8), - W_1.data[ - T.ramp( + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 2048), 1, 8)] = W_1[ + T.ramp( + ( ( ( ( - ( - (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) - + (ty * 512) - ) - + (tz * 256) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (tx * 8) + + (tz * 256) ) - + 8192 - ), - 1, - 8, + + (tx * 8) + ) + + 8192 ), - T.broadcast(True, 8), - ], + 1, + 8, + ), T.broadcast(True, 8), - ) + ] with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 4096), 1, 8), - W_1.data[ - T.ramp( + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 4096), 1, 8)] = W_1[ + T.ramp( + ( ( ( ( - ( - (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) - + (ty * 512) - ) - + (tz * 256) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (tx * 8) + + (tz * 256) ) - + 131072 - ), - 1, - 8, + + (tx * 8) + ) + + 131072 ), - T.broadcast(True, 8), - ], + 1, + 8, + ), T.broadcast(True, 8), - ) + ] with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 6144), 1, 8), - W_1.data[ - T.ramp( + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 6144), 1, 8)] = W_1[ + T.ramp( + ( ( ( ( - ( - (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) - + (ty * 512) - ) - + (tz * 256) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (tx * 8) + + (tz * 256) ) - + 139264 - ), - 1, - 8, + + (tx * 8) + ) + + 139264 ), - T.broadcast(True, 8), - ], + 1, + 8, + ), T.broadcast(True, 8), - ) + ] with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 8192), 1, 8), - W_1.data[ - T.ramp( + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 8192), 1, 8)] = W_1[ + T.ramp( + ( ( ( ( - ( - (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) - + (ty * 512) - ) - + (tz * 256) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (tx * 8) + + (tz * 256) ) - + 262144 - ), - 1, - 8, + + (tx * 8) + ) + + 262144 ), - T.broadcast(True, 8), - ], + 1, + 8, + ), T.broadcast(True, 8), - ) + ] with T.launch_thread(tx, 32): - T.store( - W_shared, - T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 10240), 1, 8), - W_1.data[ - T.ramp( + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 10240), 1, 8)] = W_1[ + T.ramp( + ( ( ( ( - ( - (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) - + (ty * 512) - ) - + (tz * 256) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (tx * 8) + + (tz * 256) ) - + 270336 - ), - 1, - 8, + + (tx * 8) + ) + + 270336 ), - T.broadcast(True, 8), - ], + 1, + 8, + ), T.broadcast(True, 8), - ) + ] for ic_inner in T.serial(0, 2): for kw in T.serial(0, 3): T.evaluate( @@ -2559,9 +2410,9 @@ def vthread_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i2, 2) B = T.allocate([16], "float32", "local") for j in range(16): - B[j] = A.data[i0 * 64 + i1 * 32 + i2 * 16 + j] + T.float32(1) + B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + T.float32(1) for j in range(16): - C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * T.float32(2) + C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * T.float32(2) def test_vthread(): @@ -2840,7 +2691,7 @@ def test_rank0_buffers(): def rank0_block(a: T.handle) -> None: A = T.match_buffer(a, (), "float32") B = T.alloc_buffer((), "float32") - T.store(B.data, 0, A.data[0]) + B[0] = A.data[0] with T.block("update") as []: T.reads([A[()]]) @@ -2984,12 +2835,12 @@ def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.han for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29.data[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")), True) + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) for ax0_ax1_fused_5 in T.serial(0, 56): for ax2_5, ax3_3 in T.grid(56, 64): - T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16"), True) + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") # fmt: on @@ -3009,7 +2860,7 @@ def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_x, 128) reduce_temp0 = T.allocate([1], "float32", "local") with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): - T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A.data[i * 128 + threadIdx_x], True, reduce_temp0, threadIdx_x, dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A[i * 128 + threadIdx_x], True, reduce_temp0, threadIdx_x, dtype="handle")) @T.prim_func @@ -3021,7 +2872,7 @@ def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_x, 128) reduce_temp0 = T.allocate([1], "float32", "local") with T.attr(T.comm_reducer(lambda x0, x1, y0, y1: (T.Select((x1 >= y1), x0, y0), T.Select((x1 >= y1), x1, y1)), [T.int32(-1), T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): - T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A.data[i * 128 + threadIdx_x], True, reduce_temp0, threadIdx_x, dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A[i * 128 + threadIdx_x], True, reduce_temp0, threadIdx_x, dtype="handle")) @T.prim_func From 09859a6982a621622813d0a3d04578ed05a59d05 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 20 Jan 2022 09:25:58 -0600 Subject: [PATCH 038/177] Added LOG(FATAL) in constructor of Store/Load nodes. --- src/tir/ir/expr.cc | 2 ++ src/tir/ir/stmt.cc | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index afe24b73b80e..8f904d5bd8f7 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -626,6 +626,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Load Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, Span span) { + LOG(FATAL) << "Unexpected use of deprecated Store node for buffer " << buffer_var->name_hint + << ". Use BufferStore instead."; ICHECK(buffer_var.defined()); ICHECK(predicate.defined()); ICHECK(index.defined()); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 078561c447ad..484f8aa851f5 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -241,6 +241,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Store Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) { + LOG(FATAL) << "Unexpected use of deprecated Store node for buffer " << buffer_var->name_hint + << ". Use BufferStore instead."; ICHECK(value.defined()); ICHECK(index.defined()); ICHECK(predicate.defined()); From 8b5aab4e2de1a81e45839c6d6f4aad48d8f6e8e4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 20 Jan 2022 09:26:29 -0600 Subject: [PATCH 039/177] Updated tvmscript parser to report error for Store/Load nodes. --- python/tvm/script/parser.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index dca366bf4269..5939e08be852 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -630,13 +630,8 @@ def transform_SubscriptAssign(self, node): f"Store is only allowed with one index, but {len(indexes)} were provided.", node.params[1].span, ) - # Store - return tvm.tir.Store( - symbol, - tvm.runtime.convert(rhs, span=rhs_span), - indexes[0], - tvm.runtime.convert(True, span=tvm_span_from_synr(node.span)), - span=tvm_span_from_synr(node.span), + self.report_error( + "Use of tir.Store has been deprecated in favor of tir.BufferStore.", node.span ) def transform_Assert(self, node): @@ -950,15 +945,8 @@ def transform_Subscript(self, node): node.span, ) - return call_with_error_reporting( - self.report_error, - node.span, - tvm.tir.Load, - "float32", - symbol, - index, - True, - span=tvm_span_from_synr(node.span), + self.report_error( + "Use of tir.Load has been deprecated in favor of tir.BufferLoad", node.span ) elif isinstance(symbol, tvm.tir.Buffer): return BufferSlice( From bc1b95752cab9ddb884360fec742b51eaa4f9f82 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 25 Jan 2022 09:47:45 -0600 Subject: [PATCH 040/177] [TVMScript] Added T.preflattened_buffer stmt Used to specify `PrimFunc::preflattened_buffer_map`. Takes an argument of the postflattened buffer, so that it will work for both simple declarations and `T.match_buffer` statements without needing to introduce a param handle. All other arguments are identical to `T.match_buffer.` --- python/tvm/script/tir/special_stmt.py | 54 +++++++++++++++++++++++++++ src/printer/tvmscript_printer.cc | 24 ++++++++++++ 2 files changed, 78 insertions(+) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index a513fd087c4b..05497b975127 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -870,6 +870,60 @@ def func_attr(dict_attr, span): super().__init__(func_attr, def_symbol=False) +@register +class PreflattenedBufferMap(SpecialStmt): + """Special Stmt for declaring the PrimFunc::preflattened_buffer_map + + Example + ------- + .. code-block:: python + T.preflattened_buffer_map({}) + """ + + def __init__(self): + def preflattened_buffer( + postflattened, + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + span=None, + ): + + param = None + for key, value in self.context.func_buffer_map.items(): + if value.same_as(postflattened): + param = key + + assert ( + param is not None + ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map." + + buffer_name: str = f"{postflattened.name}_preflatten" + preflattened = tvm.tir.decl_buffer( + shape, + dtype, + buffer_name, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + span=span, + ) + + self.context.func_preflattened_buffer_map[param] = preflattened + + super().__init__(preflattened_buffer, def_symbol=False) + + @register class TargetAttrValue(SpecialStmt): """Special Stmt for target attr value. diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index c4b48e98de73..b2d3922a0bb4 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1434,9 +1434,30 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { if (simple_buf.count(buf)) continue; buf_not_in_headers_.insert(buf.get()); body << Print(buf) << " = " << tir_prefix_ << ".match_buffer("; + ICHECK(memo_buf_decl_.count(buf)); body << Print((*it).first) << ", " << memo_buf_decl_[buf]; body << ")" << Doc::NewLine(); } + // print preflattened buffer map + for (const auto& param : op->params) { + auto pf_buf_it = op->preflattened_buffer_map.find(param); + if (pf_buf_it != op->preflattened_buffer_map.end()) { + const Buffer& preflattened = (*pf_buf_it).second; + + auto buf_it = op->buffer_map.find(param); + ICHECK(buf_it != op->buffer_map.end()) << "Found pre-flattened buffer " << preflattened->name + << " with no corresponding post-flatten buffer."; + const Buffer& postflattened = (*buf_it).second; + + // Call Print() without assigning in order to fill memo_buf_decl_. + Print(preflattened); + buf_not_in_headers_.insert(preflattened.get()); + ICHECK(memo_buf_decl_.count(preflattened)); + + body << tir_prefix_ << ".preflattened_buffer(" << Print(postflattened) << ", " + << memo_buf_decl_.at(preflattened) << ")" << Doc::NewLine(); + } + } // print body body << "# body" << Doc::NewLine(); if (op->body->IsInstance() && @@ -1464,6 +1485,9 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { header_attr << PrintSep(attrs, Doc::Text(", ")) << "})"; } // print buffer declarations(buffers not defined by buffer_bind or buffer_allocate) + + // TODO: Now that T.allocate returns a buffer object, it shouldn't + // have a buffer_decl anymore. Doc header_buf; std::vector bufs; for (const auto& it : memo_buf_) { From 6c111fc3afb9c841ff36a224b96f244e0ca2ad55 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 25 Jan 2022 13:37:59 -0600 Subject: [PATCH 041/177] [TVMScript] Updated TVMscript for BufferLoad/BufferStore - Use `T.preflattened_buffer` calls in TVMScript to represent `PrimFunc::preflattened_buffer_map`. - Remove `T.buffer_decl` for return value of `T.allocate`, now that `T.allocate` returns a buffer. - For buffer access as a different type, make a `T.buffer_decl` for those accesses. --- src/printer/tvmscript_printer.cc | 3 --- src/tir/ir/buffer.cc | 14 +++++++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index b2d3922a0bb4..7931b1cddea5 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1485,9 +1485,6 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { header_attr << PrintSep(attrs, Doc::Text(", ")) << "})"; } // print buffer declarations(buffers not defined by buffer_bind or buffer_allocate) - - // TODO: Now that T.allocate returns a buffer object, it shouldn't - // have a buffer_decl anymore. Doc header_buf; std::vector bufs; for (const auto& it : memo_buf_) { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 2ec2f49f0c69..f192d6bd11c9 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -515,9 +515,17 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array if (storage_dtype == DataType::Bool()) { storage_dtype = DataType::Int(8); } - ICHECK(IsPointerType(data->type_annotation, storage_dtype)) - << "Buffer data field expect to have the right pointer type annotation" - << " annotation=" << data->type_annotation << ", storage_dtype=" << storage_dtype; + // The buffer dtype may differ from the dtype of the underlying + // allocation, such as a single allocation that backs multiple + // tensors without a common datatype. Therefore, we check that the + // data pointer is a pointer, but not the exact type of the + // pointed-to values. + ICHECK(data->type_annotation.defined()) + << "Variable " << data->name_hint << " is missing a type annotation."; + ICHECK(data->type_annotation.as()) + << "Variable " << data->name_hint << " is not a pointer."; + ICHECK(data->type_annotation.as()->element_type.as()) + << "Variable " << data->name_hint << " does not point to a primitive."; auto n = make_object(); n->data = std::move(data); From fb1f9fb5ccf9fa3c43f6b36817f81875e2607055 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 19 Jan 2022 12:09:08 -0600 Subject: [PATCH 042/177] Updated test_tvmscript_roundtrip.py for BufferLoad/BufferStore. --- python/tvm/script/parser.py | 38 +- .../unittest/test_tvmscript_roundtrip.py | 4822 +++++++++-------- 2 files changed, 2481 insertions(+), 2379 deletions(-) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 5939e08be852..8beff4eb27f4 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -553,7 +553,11 @@ def transform_Assign(self, node): if isinstance(node.rhs, ast.Call): # Pattern 1 & Pattern 4 - func = self.transform(node.rhs.func_name) + if isinstance(node.rhs.func_name, ast.Op): + func = None + else: + func = self.transform(node.rhs.func_name) + if isinstance(func, WithScopeHandler): if not func.concise_scope or not func.def_symbol: self.report_error( @@ -634,6 +638,25 @@ def transform_SubscriptAssign(self, node): "Use of tir.Store has been deprecated in favor of tir.BufferStore.", node.span ) + def transform_AttrAssign(self, node): + """Visitor for statements of the form :code:`x.y = 2`.""" + obj = self.transform(node.params[0]) + field = node.params[1] + value = self.transform(node.params[2]) + + if not hasattr(obj, field.name): + self.error(f"Field {field.name} does not exist", field.span) + + var = getattr(obj, field.name) + + if not isinstance(var, tvm.tir.Var): + self.error( + f"Can only assign to tir.Var attributes, not {type(var).__name__}", node.span + ) + + body = self.parse_body(node) + return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) + def transform_Assert(self, node): """Assert visitor @@ -859,13 +882,16 @@ def f(): """ # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign. if isinstance(node.call.func_name, ast.Op): - if node.call.func_name.name != ast.BuiltinOp.SubscriptAssign: - self.report_error( - "Binary and unary operators are not allowed as a statement", node.span - ) - else: + if node.call.func_name.name == ast.BuiltinOp.SubscriptAssign: return self.transform_SubscriptAssign(node.call) + if node.call.func_name.name == ast.BuiltinOp.AttrAssign: + return self.transform_AttrAssign(node.call) + + self.report_error( + "Binary and unary operators are not allowed as a statement", node.span + ) + # handle a regular function call func = self.transform(node.call.func_name) arg_list = self.parse_arg_list(func, node.call) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index df397fb81c73..a6f22adc0858 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -19,749 +19,684 @@ import pytest import tvm +import tvm.testing from tvm import tir from tvm.script import tir as T -@tvm.script.ir_module -class Module1: - @T.prim_func - def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) - # buffer definition - C_global = T.buffer_decl([1024, 1024], elem_offset=0, align=128, offset_factor=1) - packedB = T.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) - A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - # body - T.realize(packedB[0:32, 0:1024, 0:32], "") - for x in T.parallel(0, 32): - for y in T.serial(0, 1024): - for z in T.vectorized(0, 32): - packedB[x, y, z] = B_1[y, ((x * 32) + z)] - T.realize(C_1[0:1024, 0:1024], "") - for x_outer in T.parallel(0, 32): - for y_outer in T.serial(0, 32): - T.realize( - C_global[ - (x_outer * 32) : ((x_outer * 32) + 32), - (y_outer * 32) : ((y_outer * 32) + 32), - ], - "global", - ) - for x_c_init in T.serial(0, 32): - for y_c_init in T.vectorized(0, 32): +def opt_gemm_normalize(): + @tvm.script.ir_module + class Module: + @T.prim_func + def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + # buffer definition + C_global = T.buffer_decl([1024, 1024], elem_offset=0, align=128, offset_factor=1) + packedB = T.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) + A_1 = T.match_buffer(A, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = T.match_buffer(B, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = T.match_buffer(C, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) + # body + T.realize(packedB[0:32, 0:1024, 0:32], "") + for x in T.parallel(0, 32): + for y in T.serial(0, 1024): + for z in T.vectorized(0, 32): + packedB[x, y, z] = B_1[y, ((x * 32) + z)] + T.realize(C_1[0:1024, 0:1024], "") + for x_outer in T.parallel(0, 32): + for y_outer in T.serial(0, 32): + T.realize( C_global[ - (x_c_init + (x_outer * 32)), (y_c_init + (y_outer * 32)) - ] = T.float32(0) - for k_outer in T.serial(0, 256): - for x_c in T.serial(0, 32): - for k_inner in T.unroll(0, 4): - for y_c in T.vectorized(0, 32): - C_global[(x_c + (x_outer * 32)), (y_c + (y_outer * 32))] = C_global[ - (x_c + (x_outer * 32)), (y_c + (y_outer * 32)) - ] + ( - A_1[(x_c + (x_outer * 32)), (k_inner + (k_outer * 4))] - * packedB[ - T.floordiv((y_c + (y_outer * 32)), 32), - (k_inner + (k_outer * 4)), - T.floormod((y_c + (y_outer * 32)), 32), - ] - ) - for x_inner in T.serial(0, 32): - for y_inner in T.serial(0, 32): - C_1[(x_inner + (x_outer * 32)), (y_inner + (y_outer * 32))] = C_global[ - (x_inner + (x_outer * 32)), (y_inner + (y_outer * 32)) - ] - - -def test_opt_gemm_normalize(): - mod = Module1 - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) - - -@tvm.script.ir_module -class Module2: - @T.prim_func - def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) - A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - # body - packedB = T.allocate([32768], "float32", "global") - for x in T.parallel(0, 32): - for y in T.serial(0, 1024): - packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B_1[ - T.ramp(((y * 1024) + (x * 32)), 1, 32) - ] - for x_outer in T.parallel(0, 32): - C_global = T.allocate([1024], "float32", "global") - for y_outer in T.serial(0, 32): - for x_c_init in T.serial(0, 32): - C_global[T.ramp((x_c_init * 32), 1, 32)] = T.broadcast(T.float32(0), 32) - for k_outer in T.serial(0, 256): - for x_c in T.serial(0, 32): - C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ - T.ramp((x_c * 32), 1, 32) - ] + ( - T.broadcast( - A_1[ - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)), - ], - 32, - ) - * packedB[T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32)] - ) - C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ - T.ramp((x_c * 32), 1, 32) - ] + ( - T.broadcast( - A_1[ - ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 1), - ], - 32, - ) - * packedB[T.ramp((((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32)] - ) - C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ - T.ramp((x_c * 32), 1, 32) - ] + ( - T.broadcast( - A_1[ - ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 2), - ], - 32, - ) - * packedB[T.ramp((((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32)] - ) - C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ - T.ramp((x_c * 32), 1, 32) - ] + ( - T.broadcast( - A_1[ - ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 3), - ], - 32, - ) - * packedB[T.ramp((((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32)] - ) - for x_inner in T.serial(0, 32): - for y_inner in T.serial(0, 32): - C_1[ - ((((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) + y_inner) - ] = C_global[((x_inner * 32) + y_inner)] - - -def test_opt_gemm_lower(): - mod = Module2 - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) - - -@tvm.script.ir_module -class Module3: - @T.prim_func - def mmult( - args: T.handle, - arg_type_ids: T.handle, - num_args: T.int32, - out_ret_value: T.handle, - out_ret_tcode: T.handle, - ) -> T.int32: - # function attr dict - T.func_attr( - { - "tir.noalias": True, - "global_symbol": "mmult", - "tir.is_entry_func": True, - "calling_conv": 1, - } - ) - # var definition - C_global = T.buffer_var("float32", "global") - packedB = T.buffer_var("float32", "global") - # body - assert num_args == 3, "mmult: num_args should be 3" - arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") - arg0_code: T.int32 = arg_type_ids[0] - arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") - arg1_code: T.int32 = arg_type_ids[1] - arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") - arg2_code: T.int32 = arg_type_ids[2] - A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") - T.attr(A, "storage_alignment", 128) - arg0_shape: T.handle = T.tvm_struct_get(arg0, 0, 2, dtype="handle") - arg0_strides: T.handle = T.tvm_struct_get(arg0, 0, 3, dtype="handle") - dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") - B: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle") - T.attr(B, "storage_alignment", 128) - arg1_shape: T.handle = T.tvm_struct_get(arg1, 0, 2, dtype="handle") - arg1_strides: T.handle = T.tvm_struct_get(arg1, 0, 3, dtype="handle") - C: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle") - T.attr(C, "storage_alignment", 128) - arg2_shape: T.handle = T.tvm_struct_get(arg2, 0, 2, dtype="handle") - arg2_strides: T.handle = T.tvm_struct_get(arg2, 0, 3, dtype="handle") - assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( - arg0_code == 4 - ), "mmult: Expect arg[0] to be pointer" - assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or ( - arg1_code == 4 - ), "mmult: Expect arg[1] to be pointer" - assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( - arg2_code == 4 - ), "mmult: Expect arg[2] to be pointer" - assert 2 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 2" - assert 2 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 2" - assert ( - (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(32)) - ) and ( - T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg0.dtype is expected to be float32" - assert 1024 == T.cast( - arg0_shape[0], "int32" - ), "Argument arg0.shape[0] has an unsatisfied constraint" - assert 1024 == T.cast( - arg0_shape[1], "int32" - ), "Argument arg0.shape[1] has an unsatisfied constraint" - if not (T.isnullptr(arg0_strides, dtype="bool")): - assert (1 == T.cast(arg0_strides[1], "int32")) and ( - 1024 == T.cast(arg0_strides[0], "int32") - ), "arg0.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg0, 0, 8, dtype="uint64" - ), "Argument arg0.byte_offset has an unsatisfied constraint" - assert 1 == T.tvm_struct_get( - arg0, 0, 10, dtype="int32" - ), "Argument arg0.device_type has an unsatisfied constraint" - assert 2 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 2" - assert 2 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 2" - assert ( - (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(32)) - ) and ( - T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg1.dtype is expected to be float32" - assert 1024 == T.cast( - arg1_shape[0], "int32" - ), "Argument arg1.shape[0] has an unsatisfied constraint" - assert 1024 == T.cast( - arg1_shape[1], "int32" - ), "Argument arg1.shape[1] has an unsatisfied constraint" - if not (T.isnullptr(arg1_strides, dtype="bool")): - assert (1 == T.cast(arg1_strides[1], "int32")) and ( - 1024 == T.cast(arg1_strides[0], "int32") - ), "arg1.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg1, 0, 8, dtype="uint64" - ), "Argument arg1.byte_offset has an unsatisfied constraint" - assert 1 == T.tvm_struct_get( - arg1, 0, 10, dtype="int32" - ), "Argument arg1.device_type has an unsatisfied constraint" - assert dev_id == T.tvm_struct_get( - arg1, 0, 9, dtype="int32" - ), "Argument arg1.device_id has an unsatisfied constraint" - assert 2 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 2" - assert 2 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 2" - assert ( - (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") == T.uint8(32)) - ) and ( - T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg2.dtype is expected to be float32" - assert 1024 == T.cast( - arg2_shape[0], "int32" - ), "Argument arg2.shape[0] has an unsatisfied constraint" - assert 1024 == T.cast( - arg2_shape[1], "int32" - ), "Argument arg2.shape[1] has an unsatisfied constraint" - if not (T.isnullptr(arg2_strides, dtype="bool")): - assert (1 == T.cast(arg2_strides[1], "int32")) and ( - 1024 == T.cast(arg2_strides[0], "int32") - ), "arg2.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg2, 0, 8, dtype="uint64" - ), "Argument arg2.byte_offset has an unsatisfied constraint" - assert 1 == T.tvm_struct_get( - arg2, 0, 10, dtype="int32" - ), "Argument arg2.device_type has an unsatisfied constraint" - assert dev_id == T.tvm_struct_get( - arg2, 0, 9, dtype="int32" - ), "Argument arg2.device_id has an unsatisfied constraint" - T.attr(0, "compute_scope", "mmult_compute_") - T.attr(packedB, "storage_scope", "global") - T.attr(packedB, "storage_alignment", 128) - with T.let( - packedB, - T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4194304), 2, 32, dtype="handle"), - ): - if T.isnullptr(packedB, dtype="bool"): - T.evaluate(T.tvm_throw_last_error(dtype="int32")) + (x_outer * 32) : ((x_outer * 32) + 32), + (y_outer * 32) : ((y_outer * 32) + 32), + ], + "global", + ) + for x_c_init in T.serial(0, 32): + for y_c_init in T.vectorized(0, 32): + C_global[ + (x_c_init + (x_outer * 32)), (y_c_init + (y_outer * 32)) + ] = T.float32(0) + for k_outer in T.serial(0, 256): + for x_c in T.serial(0, 32): + for k_inner in T.unroll(0, 4): + for y_c in T.vectorized(0, 32): + C_global[ + (x_c + (x_outer * 32)), (y_c + (y_outer * 32)) + ] = C_global[(x_c + (x_outer * 32)), (y_c + (y_outer * 32))] + ( + A_1[(x_c + (x_outer * 32)), (k_inner + (k_outer * 4))] + * packedB[ + T.floordiv((y_c + (y_outer * 32)), 32), + (k_inner + (k_outer * 4)), + T.floormod((y_c + (y_outer * 32)), 32), + ] + ) + for x_inner in T.serial(0, 32): + for y_inner in T.serial(0, 32): + C_1[(x_inner + (x_outer * 32)), (y_inner + (y_outer * 32))] = C_global[ + (x_inner + (x_outer * 32)), (y_inner + (y_outer * 32)) + ] + + return Module + + +def opt_gemm_lower(): + @tvm.script.ir_module + class Module: + @T.prim_func + def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + # body + packedB = T.allocate([32768], "float32", "global") for x in T.parallel(0, 32): for y in T.serial(0, 1024): - packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B[ - T.ramp(((y * 1024) + (x * 32)), 1, 32), - T.broadcast(True, 32), + packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B_1[ + T.ramp(((y * 1024) + (x * 32)), 1, 32) ] for x_outer in T.parallel(0, 32): - T.attr(C_global, "storage_scope", "global") - T.attr(C_global, "storage_alignment", 128) - with T.let( - C_global, - T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4096), 2, 32, dtype="handle"), - ): - if T.isnullptr(C_global, dtype="bool"): - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - for y_outer in T.serial(0, 32): - for x_c_init in T.serial(0, 32): - C_global[T.ramp((x_c_init * 32), 1, 32)] = T.broadcast(T.float32(0), 32) - for k_outer in T.serial(0, 256): - for x_c in T.serial(0, 32): - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.uint32(3), - T.broadcast( - A[ - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)), - ], - 32, - ), - packedB[ - T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), - T.broadcast(True, 32), + C_global = T.allocate([1024], "float32", "global") + for y_outer in T.serial(0, 32): + for x_c_init in T.serial(0, 32): + C_global[T.ramp((x_c_init * 32), 1, 32)] = T.broadcast(T.float32(0), 32) + for k_outer in T.serial(0, 256): + for x_c in T.serial(0, 32): + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)), ], - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), - ], - dtype="float32x32", + 32, ) - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.uint32(3), - T.broadcast( - A[ - ( - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) - + 1 - ), - ], - 32, - ), - packedB[ - T.ramp((((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32), - T.broadcast(True, 32), - ], - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), + * packedB[T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32)] + ) + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 1), ], - dtype="float32x32", + 32, ) - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.uint32(3), - T.broadcast( - A[ - ( - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) - + 2 - ), - ], - 32, - ), - packedB[ - T.ramp((((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32), - T.broadcast(True, 32), - ], - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), + * packedB[ + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32) + ] + ) + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 2), ], - dtype="float32x32", + 32, ) - C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( - T.uint32(97), - T.uint32(3), - T.broadcast( - A[ - ( - (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) - + 3 - ), - ], - 32, - ), - packedB[ - T.ramp((((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32), - T.broadcast(True, 32), - ], - C_global[ - T.ramp((x_c * 32), 1, 32), - T.broadcast(True, 32), + * packedB[ + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32) + ] + ) + C_global[T.ramp((x_c * 32), 1, 32)] = C_global[ + T.ramp((x_c * 32), 1, 32) + ] + ( + T.broadcast( + A_1[ + ((((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)) + 3), ], - dtype="float32x32", + 32, ) - for x_inner in T.serial(0, 32): - for y_inner in T.serial(0, 32): - C[ - ( - (((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) - + y_inner - ) - ] = C_global[((x_inner * 32) + y_inner)] - if T.TVMBackendFreeWorkspace(1, dev_id, C_global, dtype="int32") != 0: - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - if T.TVMBackendFreeWorkspace(1, dev_id, packedB, dtype="int32") != 0: - T.evaluate(T.tvm_throw_last_error(dtype="int32")) - - -def test_opt_gemm_mod_host(): - mod = Module3 - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) - - -@T.prim_func -def opt_conv_tensorcore_normalize(A: T.handle, W: T.handle, Conv: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) - # var definition - bx = T.env_thread("blockIdx.x") - by = T.env_thread("blockIdx.y") - bz = T.env_thread("blockIdx.z") - tx = T.env_thread("threadIdx.x") - ty = T.env_thread("threadIdx.y") - tz = T.env_thread("threadIdx.z") - # buffer definition - Apad_shared = T.buffer_decl( - [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - Apad_shared_wmma_matrix_a = T.buffer_decl( - [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - BA = T.buffer_decl( - [16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256 - ) - BB = T.buffer_decl( - [16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256 - ) - BC = T.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) - Conv_wmma_accumulator = T.buffer_decl( - [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 - ) - W_shared = T.buffer_decl( - [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - W_shared_wmma_matrix_b = T.buffer_decl( - [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - buffer = T.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256) - buffer_1 = T.buffer_decl( - [16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256 - ) - buffer_2 = T.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256) - buffer_3 = T.buffer_decl( - [16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256 - ) - buffer_4 = T.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) - buffer_5 = T.buffer_decl([16, 16], align=32, offset_factor=256) - A_1 = T.match_buffer( - A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - W_1 = T.match_buffer( - W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - Conv_1 = T.match_buffer( - Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 - ) - # body - T.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16], "") - T.launch_thread(bz, 196) - T.launch_thread(bx, 2) - T.launch_thread(by, 4) - T.launch_thread(ty, 4) - T.launch_thread(tz, 2) - T.realize( - Conv_wmma_accumulator[ - ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), - T.floordiv(bz, 14) : (T.floordiv(bz, 14) + 1), - T.floormod(bz, 14) : (T.floormod(bz, 14) + 1), - ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4), - 0:16, - 0:16, - ], - "wmma.accumulator", - ) - for n_c_init in T.serial(0, 2): - for o_c_init in T.serial(0, 4): - T.attr( - [BC, Conv_wmma_accumulator], - "buffer_bind_scope", - T.tvm_tuple( - (n_c_init + ((bx * 8) + (ty * 2))), - 1, - T.floordiv(bz, 14), - 1, - T.floormod(bz, 14), - 1, - (o_c_init + ((by * 8) + (tz * 4))), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.evaluate( - T.tvm_fill_fragment( - BC.data, - 16, - 16, - 16, - T.floordiv(BC.elem_offset, 256), - T.float32(0), - dtype="handle", - ) - ) - for ic_outer in T.serial(0, 8): - for kh in T.serial(0, 3): - T.realize( - Apad_shared[ - (bx * 8) : ((bx * 8) + 8), - (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz, 14) + kh) + 1), - T.floormod(bz, 14) : (T.floormod(bz, 14) + 3), - (ic_outer * 2) : ((ic_outer * 2) + 2), - 0:16, - 0:16, - ], - "shared", - ) - for ax2 in T.serial(0, 3): - for ax3 in T.serial(0, 2): - for ax4_ax5_fused_outer in T.serial(0, 8): - T.launch_thread(tx, 32) - Apad_shared[ - ((tz + (ty * 2)) + (bx * 8)), - (T.floordiv(bz, 14) + kh), - (ax2 + T.floormod(bz, 14)), - (ax3 + (ic_outer * 2)), - T.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), - T.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), - ] = T.if_then_else( - ( + * packedB[ + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32) + ] + ) + for x_inner in T.serial(0, 32): + for y_inner in T.serial(0, 32): + C_1[ ( - ( - ((T.floordiv(bz, 14) + kh) >= 1) - and (((T.floordiv(bz, 14) + kh) - 1) < 14) - ) - and ((ax2 + T.floormod(bz, 14)) >= 1) + (((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) + + y_inner ) - and (((ax2 + T.floormod(bz, 14)) - 1) < 14) - ), - A_1[ - ((tz + (ty * 2)) + (bx * 8)), - ((T.floordiv(bz, 14) + kh) - 1), - ((ax2 + T.floormod(bz, 14)) - 1), - (ax3 + (ic_outer * 2)), - T.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), - T.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), - ], - T.float16(0), - dtype="float16", - ) - T.realize( - W_shared[ - kh : (kh + 1), - 0:3, - (ic_outer * 2) : ((ic_outer * 2) + 2), - (by * 8) : ((by * 8) + 8), - 0:16, - 0:16, - ], - "shared", + ] = C_global[((x_inner * 32) + y_inner)] + + return Module + + +def opt_gemm_mod_host(): + @tvm.script.ir_module + class Module: + @T.prim_func + def mmult( + args: T.handle, + arg_type_ids: T.handle, + num_args: T.int32, + out_ret_value: T.handle, + out_ret_tcode: T.handle, + ) -> T.int32: + # function attr dict + T.func_attr( + { + "tir.noalias": True, + "global_symbol": "mmult", + "tir.is_entry_func": True, + "calling_conv": 1, + } ) - for ax1 in T.serial(0, 3): - for ax2_1 in T.serial(0, 2): - T.launch_thread(tx, 32) - for ax4_ax5_fused_inner in T.vectorized(0, 8): - W_shared[ - kh, - ax1, - (ax2_1 + (ic_outer * 2)), - ((tz + (ty * 2)) + (by * 8)), - T.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), - T.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), - ] = W_1[ - kh, - ax1, - (ax2_1 + (ic_outer * 2)), - ((tz + (ty * 2)) + (by * 8)), - T.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), - T.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), + # buffer definition + buf_type_ids = T.match_buffer(arg_type_ids, [3], dtype="int32") + + packedB = T.buffer_decl([32768], dtype="float32") + C_global = T.buffer_decl([1024], dtype="float32") + # var definition + # C_global = T.buffer_var("float32", "global") + # packedB = T.buffer_var("float32", "global") + # body + assert num_args == 3, "mmult: num_args should be 3" + arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") + arg0_code: T.int32 = buf_type_ids[0] + arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") + arg1_code: T.int32 = buf_type_ids[1] + arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") + arg2_code: T.int32 = buf_type_ids[2] + + A_data: T.Ptr[T.int32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + T.attr(A_data, "storage_alignment", 128) + A: T.Buffer = T.buffer_decl([1024, 1024], dtype="int32", data=A_data) + buf0_shape_data: T.Ptr[T.int32] = T.tvm_struct_get(arg0, 0, 2, dtype="handle") + buf0_shape: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf0_shape_data) + buf0_strides_data: T.Ptr[T.int32] = T.tvm_struct_get(arg0, 0, 3, dtype="handle") + buf0_strides: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf0_strides_data) + + dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") + + B_data: T.Ptr[T.int32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + T.attr(B_data, "storage_alignment", 128) + B: T.Buffer = T.buffer_decl([1024, 1024], dtype="int32", data=B_data) + buf1_shape_data: T.Ptr[T.int32] = T.tvm_struct_get(arg1, 0, 2, dtype="handle") + buf1_shape: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf1_shape_data) + buf1_strides_data: T.Ptr[T.int32] = T.tvm_struct_get(arg1, 0, 3, dtype="handle") + buf1_strides: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf1_strides_data) + + C_data: T.Ptr[T.int32] = T.tvm_struct_get(arg2, 0, 1, dtype="handle") + T.attr(C_data, "storage_alignment", 128) + C: T.Buffer = T.buffer_decl([1024, 1024], dtype="int32", data=C_data) + buf2_shape_data: T.Ptr[T.int32] = T.tvm_struct_get(arg2, 0, 2, dtype="handle") + buf2_shape: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf2_shape_data) + buf2_strides_data: T.Ptr[T.int32] = T.tvm_struct_get(arg2, 0, 3, dtype="handle") + buf2_strides: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf2_strides_data) + + assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( + arg0_code == 4 + ), "mmult: Expect arg[0] to be pointer" + assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or ( + arg1_code == 4 + ), "mmult: Expect arg[1] to be pointer" + assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( + arg2_code == 4 + ), "mmult: Expect arg[2] to be pointer" + assert 2 == T.tvm_struct_get( + arg0, 0, 4, dtype="int32" + ), "arg0.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get( + arg0, 0, 4, dtype="int32" + ), "arg0.ndim is expected to equal 2" + assert ( + (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(32)) + ) and ( + T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg0.dtype is expected to be float32" + assert 1024 == T.cast( + buf0_shape[0], "int32" + ), "Argument arg0.shape[0] has an unsatisfied constraint" + assert 1024 == T.cast( + buf0_shape[1], "int32" + ), "Argument arg0.shape[1] has an unsatisfied constraint" + if not (T.isnullptr(buf0_strides.data, dtype="bool")): + assert (1 == T.cast(buf0_strides[1], "int32")) and ( + 1024 == T.cast(buf0_strides[0], "int32") + ), "arg0.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg0, 0, 8, dtype="uint64" + ), "Argument arg0.byte_offset has an unsatisfied constraint" + assert 1 == T.tvm_struct_get( + arg0, 0, 10, dtype="int32" + ), "Argument arg0.device_type has an unsatisfied constraint" + assert 2 == T.tvm_struct_get( + arg1, 0, 4, dtype="int32" + ), "arg1.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get( + arg1, 0, 4, dtype="int32" + ), "arg1.ndim is expected to equal 2" + assert ( + (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(32)) + ) and ( + T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg1.dtype is expected to be float32" + assert 1024 == T.cast( + buf1_shape[0], "int32" + ), "Argument arg1.shape[0] has an unsatisfied constraint" + assert 1024 == T.cast( + buf1_shape[1], "int32" + ), "Argument arg1.shape[1] has an unsatisfied constraint" + if not (T.isnullptr(buf1_strides.data, dtype="bool")): + assert (1 == T.cast(buf1_strides[1], "int32")) and ( + 1024 == T.cast(buf1_strides[0], "int32") + ), "arg1.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg1, 0, 8, dtype="uint64" + ), "Argument arg1.byte_offset has an unsatisfied constraint" + assert 1 == T.tvm_struct_get( + arg1, 0, 10, dtype="int32" + ), "Argument arg1.device_type has an unsatisfied constraint" + assert dev_id == T.tvm_struct_get( + arg1, 0, 9, dtype="int32" + ), "Argument arg1.device_id has an unsatisfied constraint" + assert 2 == T.tvm_struct_get( + arg2, 0, 4, dtype="int32" + ), "arg2.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get( + arg2, 0, 4, dtype="int32" + ), "arg2.ndim is expected to equal 2" + assert ( + (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") == T.uint8(32)) + ) and ( + T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg2.dtype is expected to be float32" + assert 1024 == T.cast( + buf2_shape[0], "int32" + ), "Argument arg2.shape[0] has an unsatisfied constraint" + assert 1024 == T.cast( + buf2_shape[1], "int32" + ), "Argument arg2.shape[1] has an unsatisfied constraint" + if not (T.isnullptr(buf2_strides.data, dtype="bool")): + assert (1 == T.cast(buf2_strides[1], "int32")) and ( + 1024 == T.cast(buf2_strides[0], "int32") + ), "arg2.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg2, 0, 8, dtype="uint64" + ), "Argument arg2.byte_offset has an unsatisfied constraint" + assert 1 == T.tvm_struct_get( + arg2, 0, 10, dtype="int32" + ), "Argument arg2.device_type has an unsatisfied constraint" + assert dev_id == T.tvm_struct_get( + arg2, 0, 9, dtype="int32" + ), "Argument arg2.device_id has an unsatisfied constraint" + T.attr(0, "compute_scope", "mmult_compute_") + T.attr(packedB.data, "storage_scope", "global") + T.attr(packedB.data, "storage_alignment", 128) + with T.let( + packedB.data, + T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4194304), 2, 32, dtype="handle"), + ): + if T.isnullptr(packedB.data, dtype="bool"): + T.evaluate(T.tvm_throw_last_error(dtype="int32")) + for x in T.parallel(0, 32): + for y in T.serial(0, 1024): + packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B[ + T.ramp(((y * 1024) + (x * 32)), 1, 32) ] - for ic_inner in T.serial(0, 2): - for kw in T.serial(0, 3): - T.realize( - Apad_shared_wmma_matrix_a[ - ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), - (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz, 14) + kh) + 1), - (kw + T.floormod(bz, 14)) : ((kw + T.floormod(bz, 14)) + 1), - ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1), - 0:16, - 0:16, - ], - "wmma.matrix_a", + for x_outer in T.parallel(0, 32): + T.attr(C_global.data, "storage_scope", "global") + T.attr(C_global.data, "storage_alignment", 128) + with T.let( + C_global.data, + T.TVMBackendAllocWorkspace( + 1, dev_id, T.uint64(4096), 2, 32, dtype="handle" + ), + ): + if T.isnullptr(C_global.data, dtype="bool"): + T.evaluate(T.tvm_throw_last_error(dtype="int32")) + for y_outer in T.serial(0, 32): + for x_c_init in T.serial(0, 32): + C_global[T.ramp((x_c_init * 32), 1, 32)] = T.broadcast( + T.float32(0), 32 + ) + for k_outer in T.serial(0, 256): + for x_c in T.serial(0, 32): + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + A[ + ( + ((x_outer * 32768) + (x_c * 1024)) + + (k_outer * 4) + ), + ], + 32, + ), + packedB[ + T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32) + ], + C_global[T.ramp((x_c * 32), 1, 32)], + dtype="float32x32", + ) + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + A[ + ( + ( + ((x_outer * 32768) + (x_c * 1024)) + + (k_outer * 4) + ) + + 1 + ), + ], + 32, + ), + packedB[ + T.ramp( + (((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32 + ) + ], + C_global[T.ramp((x_c * 32), 1, 32)], + dtype="float32x32", + ) + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + A[ + ( + ( + ((x_outer * 32768) + (x_c * 1024)) + + (k_outer * 4) + ) + + 2 + ), + ], + 32, + ), + packedB[ + T.ramp( + (((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32 + ) + ], + C_global[T.ramp((x_c * 32), 1, 32)], + dtype="float32x32", + ) + C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + A[ + ( + ( + ((x_outer * 32768) + (x_c * 1024)) + + (k_outer * 4) + ) + + 3 + ), + ], + 32, + ), + packedB[ + T.ramp( + (((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32 + ) + ], + C_global[T.ramp((x_c * 32), 1, 32)], + dtype="float32x32", + ) + for x_inner in T.serial(0, 32): + for y_inner in T.serial(0, 32): + C[ + ( + ( + ((x_outer * 32768) + (x_inner * 1024)) + + (y_outer * 32) + ) + + y_inner + ) + ] = C_global[((x_inner * 32) + y_inner)] + if T.TVMBackendFreeWorkspace(1, dev_id, C_global.data, dtype="int32") != 0: + T.evaluate(T.tvm_throw_last_error(dtype="int32")) + if T.TVMBackendFreeWorkspace(1, dev_id, packedB.data, dtype="int32") != 0: + T.evaluate(T.tvm_throw_last_error(dtype="int32")) + + return Module + + +def opt_conv_tensorcore_normalize(): + @T.prim_func + def func(A: T.handle, W: T.handle, Conv: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # var definition + bx = T.env_thread("blockIdx.x") + by = T.env_thread("blockIdx.y") + bz = T.env_thread("blockIdx.z") + tx = T.env_thread("threadIdx.x") + ty = T.env_thread("threadIdx.y") + tz = T.env_thread("threadIdx.z") + # buffer definition + Apad_shared = T.buffer_decl( + [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + Apad_shared_wmma_matrix_a = T.buffer_decl( + [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + BA = T.buffer_decl( + [16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256 + ) + BB = T.buffer_decl( + [16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256 + ) + BC = T.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) + Conv_wmma_accumulator = T.buffer_decl( + [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 + ) + W_shared = T.buffer_decl( + [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + W_shared_wmma_matrix_b = T.buffer_decl( + [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + buffer = T.buffer_decl( + [16, 16], dtype="float16", scope="shared", align=32, offset_factor=256 + ) + buffer_1 = T.buffer_decl( + [16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256 + ) + buffer_2 = T.buffer_decl( + [16, 16], dtype="float16", scope="shared", align=32, offset_factor=256 + ) + buffer_3 = T.buffer_decl( + [16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256 + ) + buffer_4 = T.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) + buffer_5 = T.buffer_decl([16, 16], align=32, offset_factor=256) + A_1 = T.match_buffer( + A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + W_1 = T.match_buffer( + W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + Conv_1 = T.match_buffer( + Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 + ) + # body + T.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16], "") + T.launch_thread(bz, 196) + T.launch_thread(bx, 2) + T.launch_thread(by, 4) + T.launch_thread(ty, 4) + T.launch_thread(tz, 2) + T.realize( + Conv_wmma_accumulator[ + ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), + T.floordiv(bz, 14) : (T.floordiv(bz, 14) + 1), + T.floormod(bz, 14) : (T.floormod(bz, 14) + 1), + ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4), + 0:16, + 0:16, + ], + "wmma.accumulator", + ) + for n_c_init in T.serial(0, 2): + for o_c_init in T.serial(0, 4): + T.attr( + [BC, Conv_wmma_accumulator], + "buffer_bind_scope", + T.tvm_tuple( + (n_c_init + ((bx * 8) + (ty * 2))), + 1, + T.floordiv(bz, 14), + 1, + T.floormod(bz, 14), + 1, + (o_c_init + ((by * 8) + (tz * 4))), + 1, + 0, + 16, + 0, + 16, + dtype="handle", + ), + ) + T.evaluate( + T.tvm_fill_fragment( + BC.data, + 16, + 16, + 16, + T.floordiv(BC.elem_offset, 256), + T.float32(0), + dtype="handle", ) - for ax0 in T.serial(0, 2): - T.attr( - [buffer, Apad_shared], - "buffer_bind_scope", - T.tvm_tuple( - (ax0 + ((bx * 8) + (ty * 2))), - 1, - (T.floordiv(bz, 14) + kh), - 1, - (kw + T.floormod(bz, 14)), - 1, - ((ic_outer * 2) + ic_inner), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.attr( - [buffer_1, Apad_shared_wmma_matrix_a], - "buffer_bind_scope", - T.tvm_tuple( - (ax0 + ((bx * 8) + (ty * 2))), - 1, - (T.floordiv(bz, 14) + kh), - 1, - (kw + T.floormod(bz, 14)), - 1, - ((ic_outer * 2) + ic_inner), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.evaluate( - T.tvm_load_matrix_sync( - buffer_1.data, - 16, - 16, - 16, - T.floordiv(buffer_1.elem_offset, 256), - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - buffer.data, - buffer.elem_offset, - 256, - 1, - dtype="handle", + ) + + for ic_outer in T.serial(0, 8): + for kh in T.serial(0, 3): + T.realize( + Apad_shared[ + (bx * 8) : ((bx * 8) + 8), + (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz, 14) + kh) + 1), + T.floormod(bz, 14) : (T.floormod(bz, 14) + 3), + (ic_outer * 2) : ((ic_outer * 2) + 2), + 0:16, + 0:16, + ], + "shared", + ) + for ax2 in T.serial(0, 3): + for ax3 in T.serial(0, 2): + for ax4_ax5_fused_outer in T.serial(0, 8): + T.launch_thread(tx, 32) + Apad_shared[ + ((tz + (ty * 2)) + (bx * 8)), + (T.floordiv(bz, 14) + kh), + (ax2 + T.floormod(bz, 14)), + (ax3 + (ic_outer * 2)), + T.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), + T.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), + ] = T.if_then_else( + ( + ( + ( + ((T.floordiv(bz, 14) + kh) >= 1) + and (((T.floordiv(bz, 14) + kh) - 1) < 14) + ) + and ((ax2 + T.floormod(bz, 14)) >= 1) + ) + and (((ax2 + T.floormod(bz, 14)) - 1) < 14) ), - 16, - "row_major", - dtype="handle", + A_1[ + ((tz + (ty * 2)) + (bx * 8)), + ((T.floordiv(bz, 14) + kh) - 1), + ((ax2 + T.floormod(bz, 14)) - 1), + (ax3 + (ic_outer * 2)), + T.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), + T.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), + ], + T.float16(0), + dtype="float16", ) - ) - T.realize( - W_shared_wmma_matrix_b[ - kh : (kh + 1), - kw : (kw + 1), - ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1), - ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4), - 0:16, - 0:16, - ], - "wmma.matrix_b", - ) - for ax3_1 in T.serial(0, 4): - T.attr( - [buffer_2, W_shared], - "buffer_bind_scope", - T.tvm_tuple( + T.realize( + W_shared[ + kh : (kh + 1), + 0:3, + (ic_outer * 2) : ((ic_outer * 2) + 2), + (by * 8) : ((by * 8) + 8), + 0:16, + 0:16, + ], + "shared", + ) + for ax1 in T.serial(0, 3): + for ax2_1 in T.serial(0, 2): + T.launch_thread(tx, 32) + for ax4_ax5_fused_inner in T.vectorized(0, 8): + W_shared[ kh, - 1, - kw, - 1, - ((ic_outer * 2) + ic_inner), - 1, - (ax3_1 + ((by * 8) + (tz * 4))), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.attr( - [buffer_3, W_shared_wmma_matrix_b], - "buffer_bind_scope", - T.tvm_tuple( + ax1, + (ax2_1 + (ic_outer * 2)), + ((tz + (ty * 2)) + (by * 8)), + T.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), + T.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), + ] = W_1[ kh, - 1, - kw, - 1, - ((ic_outer * 2) + ic_inner), - 1, - (ax3_1 + ((by * 8) + (tz * 4))), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), + ax1, + (ax2_1 + (ic_outer * 2)), + ((tz + (ty * 2)) + (by * 8)), + T.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), + T.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), + ] + for ic_inner in T.serial(0, 2): + for kw in T.serial(0, 3): + T.realize( + Apad_shared_wmma_matrix_a[ + ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), + (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz, 14) + kh) + 1), + (kw + T.floormod(bz, 14)) : ((kw + T.floormod(bz, 14)) + 1), + ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1), + 0:16, + 0:16, + ], + "wmma.matrix_a", ) - T.evaluate( - T.tvm_load_matrix_sync( - buffer_3.data, - 16, - 16, - 16, - T.floordiv(buffer_3.elem_offset, 256), - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - buffer_2.data, - buffer_2.elem_offset, - 256, + for ax0 in T.serial(0, 2): + T.attr( + [buffer, Apad_shared], + "buffer_bind_scope", + T.tvm_tuple( + (ax0 + ((bx * 8) + (ty * 2))), + 1, + (T.floordiv(bz, 14) + kh), + 1, + (kw + T.floormod(bz, 14)), + 1, + ((ic_outer * 2) + ic_inner), 1, + 0, + 16, + 0, + 16, dtype="handle", ), - 16, - "row_major", - dtype="handle", ) - ) - for n_c in T.serial(0, 2): - for o_c in T.serial(0, 4): T.attr( - [BA, Apad_shared_wmma_matrix_a], + [buffer_1, Apad_shared_wmma_matrix_a], "buffer_bind_scope", T.tvm_tuple( - (n_c + ((bx * 8) + (ty * 2))), + (ax0 + ((bx * 8) + (ty * 2))), 1, (T.floordiv(bz, 14) + kh), 1, - (T.floormod(bz, 14) + kw), + (kw + T.floormod(bz, 14)), 1, ((ic_outer * 2) + ic_inner), 1, @@ -772,8 +707,40 @@ def opt_conv_tensorcore_normalize(A: T.handle, W: T.handle, Conv: T.handle) -> N dtype="handle", ), ) + T.evaluate( + T.tvm_load_matrix_sync( + buffer_1.data, + 16, + 16, + 16, + T.floordiv(buffer_1.elem_offset, 256), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + buffer.data, + buffer.elem_offset, + 256, + 1, + dtype="handle", + ), + 16, + "row_major", + dtype="handle", + ) + ) + T.realize( + W_shared_wmma_matrix_b[ + kh : (kh + 1), + kw : (kw + 1), + ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1), + ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4), + 0:16, + 0:16, + ], + "wmma.matrix_b", + ) + for ax3_1 in T.serial(0, 4): T.attr( - [BB, W_shared_wmma_matrix_b], + [buffer_2, W_shared], "buffer_bind_scope", T.tvm_tuple( kh, @@ -782,7 +749,7 @@ def opt_conv_tensorcore_normalize(A: T.handle, W: T.handle, Conv: T.handle) -> N 1, ((ic_outer * 2) + ic_inner), 1, - (o_c + ((by * 8) + (tz * 4))), + (ax3_1 + ((by * 8) + (tz * 4))), 1, 0, 16, @@ -792,16 +759,16 @@ def opt_conv_tensorcore_normalize(A: T.handle, W: T.handle, Conv: T.handle) -> N ), ) T.attr( - [BC, Conv_wmma_accumulator], + [buffer_3, W_shared_wmma_matrix_b], "buffer_bind_scope", T.tvm_tuple( - (n_c + ((bx * 8) + (ty * 2))), + kh, 1, - T.floordiv(bz, 14), + kw, 1, - T.floormod(bz, 14), + ((ic_outer * 2) + ic_inner), 1, - (o_c + ((by * 8) + (tz * 4))), + (ax3_1 + ((by * 8) + (tz * 4))), 1, 0, 16, @@ -811,720 +778,853 @@ def opt_conv_tensorcore_normalize(A: T.handle, W: T.handle, Conv: T.handle) -> N ), ) T.evaluate( - T.tvm_mma_sync( - BC.data, - T.floordiv(BC.elem_offset, 256), - BA.data, - T.floordiv(BA.elem_offset, 256), - BB.data, - T.floordiv(BB.elem_offset, 256), - BC.data, - T.floordiv(BC.elem_offset, 256), + T.tvm_load_matrix_sync( + buffer_3.data, + 16, + 16, + 16, + T.floordiv(buffer_3.elem_offset, 256), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + buffer_2.data, + buffer_2.elem_offset, + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", ) ) - for n_inner in T.serial(0, 2): - for o_inner in T.serial(0, 4): - T.attr( - [buffer_4, Conv_wmma_accumulator], - "buffer_bind_scope", - T.tvm_tuple( - ((((bx * 4) + ty) * 2) + n_inner), - 1, - T.floordiv(bz, 14), - 1, - T.floormod(bz, 14), - 1, - ((((by * 2) + tz) * 4) + o_inner), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.attr( - [buffer_5, Conv_1], - "buffer_bind_scope", - T.tvm_tuple( - ((((bx * 4) + ty) * 2) + n_inner), - 1, - T.floordiv(bz, 14), - 1, - T.floormod(bz, 14), - 1, - ((((by * 2) + tz) * 4) + o_inner), - 1, - 0, - 16, - 0, - 16, - dtype="handle", - ), - ) - T.evaluate( - T.tvm_store_matrix_sync( - buffer_4.data, - 16, - 16, - 16, - T.floordiv(buffer_4.elem_offset, 256), - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - buffer_5.data, - buffer_5.elem_offset, - 256, - 2, + for n_c in T.serial(0, 2): + for o_c in T.serial(0, 4): + T.attr( + [BA, Apad_shared_wmma_matrix_a], + "buffer_bind_scope", + T.tvm_tuple( + (n_c + ((bx * 8) + (ty * 2))), + 1, + (T.floordiv(bz, 14) + kh), + 1, + (T.floormod(bz, 14) + kw), + 1, + ((ic_outer * 2) + ic_inner), + 1, + 0, + 16, + 0, + 16, + dtype="handle", + ), + ) + T.attr( + [BB, W_shared_wmma_matrix_b], + "buffer_bind_scope", + T.tvm_tuple( + kh, + 1, + kw, + 1, + ((ic_outer * 2) + ic_inner), + 1, + (o_c + ((by * 8) + (tz * 4))), + 1, + 0, + 16, + 0, + 16, + dtype="handle", + ), + ) + T.attr( + [BC, Conv_wmma_accumulator], + "buffer_bind_scope", + T.tvm_tuple( + (n_c + ((bx * 8) + (ty * 2))), + 1, + T.floordiv(bz, 14), + 1, + T.floormod(bz, 14), + 1, + (o_c + ((by * 8) + (tz * 4))), + 1, + 0, + 16, + 0, + 16, + dtype="handle", + ), + ) + T.evaluate( + T.tvm_mma_sync( + BC.data, + T.floordiv(BC.elem_offset, 256), + BA.data, + T.floordiv(BA.elem_offset, 256), + BB.data, + T.floordiv(BB.elem_offset, 256), + BC.data, + T.floordiv(BC.elem_offset, 256), + dtype="handle", + ) + ) + for n_inner in T.serial(0, 2): + for o_inner in T.serial(0, 4): + T.attr( + [buffer_4, Conv_wmma_accumulator], + "buffer_bind_scope", + T.tvm_tuple( + ((((bx * 4) + ty) * 2) + n_inner), + 1, + T.floordiv(bz, 14), + 1, + T.floormod(bz, 14), + 1, + ((((by * 2) + tz) * 4) + o_inner), + 1, + 0, + 16, + 0, + 16, dtype="handle", ), - 16, - "row_major", - dtype="handle", ) - ) + T.attr( + [buffer_5, Conv_1], + "buffer_bind_scope", + T.tvm_tuple( + ((((bx * 4) + ty) * 2) + n_inner), + 1, + T.floordiv(bz, 14), + 1, + T.floormod(bz, 14), + 1, + ((((by * 2) + tz) * 4) + o_inner), + 1, + 0, + 16, + 0, + 16, + dtype="handle", + ), + ) + T.evaluate( + T.tvm_store_matrix_sync( + buffer_4.data, + 16, + 16, + 16, + T.floordiv(buffer_4.elem_offset, 256), + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + buffer_5.data, + buffer_5.elem_offset, + 256, + 2, + dtype="handle", + ), + 16, + "row_major", + dtype="handle", + ) + ) + return func -def test_opt_conv_tensorcore_normalize(): - mod = opt_conv_tensorcore_normalize - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) - - -@T.prim_func -def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) - # body - A_1 = T.match_buffer( - A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - W_1 = T.match_buffer( - W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - Conv_1 = T.match_buffer( - Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 - ) - bx = T.env_thread("blockIdx.x") - by = T.env_thread("blockIdx.y") - bz = T.env_thread("blockIdx.z") - tx = T.env_thread("threadIdx.x") - ty = T.env_thread("threadIdx.y") - tz = T.env_thread("threadIdx.z") - T.launch_thread(bz, 196) - Conv_wmma_accumulator = T.allocate([2048], "float32", "wmma.accumulator") - Apad_shared = T.allocate([12288], "float16", "shared") - W_shared = T.allocate([12288], "float16", "shared") - Apad_shared_wmma_matrix_a = T.allocate([512], "float16", "wmma.matrix_a") - W_shared_wmma_matrix_b = T.allocate([1024], "float16", "wmma.matrix_b") - T.launch_thread(bx, 2) - T.launch_thread(by, 4) - T.launch_thread(ty, 4) - T.launch_thread(tz, 2) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 0, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 1, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 2, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 3, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 4, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 5, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 6, T.float32(0), dtype="handle") - ) - T.evaluate( - T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 7, T.float32(0), dtype="handle") - ) - for ic_outer in T.serial(0, 8): - for kh in T.serial(0, 3): - for ax2 in T.serial(0, 3): - with T.launch_thread(tx, 32): - Apad_shared[ - ((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) - ] = T.if_then_else( - ( + +def opt_conv_tensorcore_lower(): + @T.prim_func + def func(A: T.handle, W: T.handle, Conv: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # body + A_1 = T.match_buffer( + A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + W_1 = T.match_buffer( + W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 + ) + Conv_1 = T.match_buffer( + Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 + ) + bx = T.env_thread("blockIdx.x") + by = T.env_thread("blockIdx.y") + bz = T.env_thread("blockIdx.z") + tx = T.env_thread("threadIdx.x") + ty = T.env_thread("threadIdx.y") + tz = T.env_thread("threadIdx.z") + T.launch_thread(bz, 196) + Conv_wmma_accumulator = T.allocate([2048], "float32", "wmma.accumulator") + Apad_shared = T.allocate([12288], "float16", "shared") + W_shared = T.allocate([12288], "float16", "shared") + Apad_shared_wmma_matrix_a = T.allocate([512], "float16", "wmma.matrix_a") + W_shared_wmma_matrix_b = T.allocate([1024], "float16", "wmma.matrix_b") + T.launch_thread(bx, 2) + T.launch_thread(by, 4) + T.launch_thread(ty, 4) + T.launch_thread(tz, 2) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 0, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 1, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 2, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 3, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 4, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 5, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 6, T.float32(0), dtype="handle" + ) + ) + T.evaluate( + T.tvm_fill_fragment( + Conv_wmma_accumulator.data, 16, 16, 16, 7, T.float32(0), dtype="handle" + ) + ) + for ic_outer in T.serial(0, 8): + for kh in T.serial(0, 3): + for ax2 in T.serial(0, 3): + with T.launch_thread(tx, 32): + Apad_shared[ + ((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61440 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 32) - ] = T.if_then_else( - ( + - 61440 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 32) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61408 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 64) - ] = T.if_then_else( - ( + - 61408 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 64) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( + ( + ( + ( + ( + ( + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) + ) + + (bz * 4096) + ) + + (ax2 * 4096) + ) + + (ic_outer * 512) + ) + + tx + ) + - 61376 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 96) + ] = T.if_then_else( ( + ( + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61376 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 96) - ] = T.if_then_else( - ( + - 61344 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 128) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61344 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 128) - ] = T.if_then_else( - ( + - 61312 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 160) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61312 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 160) - ] = T.if_then_else( - ( + - 61280 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 192) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61280 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 192) - ] = T.if_then_else( - ( + - 61248 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 224) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61248 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 224) - ] = T.if_then_else( - ( + - 61216 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 256) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61216 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 256) - ] = T.if_then_else( - ( + - 61184 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 288) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61184 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 288) - ] = T.if_then_else( - ( + - 61152 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 320) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61152 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 320) - ] = T.if_then_else( - ( + - 61120 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 352) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61120 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 352) - ] = T.if_then_else( - ( + - 61088 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 384) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61088 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 384) - ] = T.if_then_else( - ( + - 61056 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 416) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61056 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 416) - ] = T.if_then_else( - ( + - 61024 + ), + ], + T.float16(0), + dtype="float16", + ) + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 448) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61024 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): + - 60992 + ), + ], + T.float16(0), + dtype="float16", + ) + T.launch_thread(tx, 32) Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 448) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 480) ] = T.if_then_else( ( ( @@ -1557,947 +1657,922 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ) + tx ) - - 60992 + - 60960 ), ], T.float16(0), dtype="float16", ) - T.launch_thread(tx, 32) - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 480) - ] = T.if_then_else( - ( - ( - ((1 <= (T.floordiv(bz, 14) + kh)) and ((T.floordiv(bz, 14) + kh) < 15)) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + with T.launch_thread(tx, 32): + W_shared[T.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8)] = W_1[ + T.ramp( ( ( ( - ( - ( - (((bx * 6422528) + (ty * 1605632)) + (tz * 802816)) - + (kh * 57344) - ) - + (bz * 4096) - ) - + (ax2 * 4096) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (ic_outer * 512) + + (tz * 256) ) - + tx - ) - - 60960 - ), - ], - T.float16(0), - dtype="float16", - ) - with T.launch_thread(tx, 32): - W_shared[T.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8)] = W_1[ - T.ramp( - ( - ( - ((((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + (ty * 512)) - + (tz * 256) - ) - + (tx * 8) - ), - 1, - 8, - ), - T.broadcast(True, 8), - ] - with T.launch_thread(tx, 32): - W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 2048), 1, 8)] = W_1[ - T.ramp( - ( + + (tx * 8) + ), + 1, + 8, + ) + ] + with T.launch_thread(tx, 32): + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 2048), 1, 8)] = W_1[ + T.ramp( ( ( ( - (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) - + (ty * 512) + ( + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) + ) + + (tz * 256) ) - + (tz * 256) + + (tx * 8) ) - + (tx * 8) - ) - + 8192 - ), - 1, - 8, - ), - T.broadcast(True, 8), - ] - with T.launch_thread(tx, 32): - W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 4096), 1, 8)] = W_1[ - T.ramp( - ( + + 8192 + ), + 1, + 8, + ) + ] + with T.launch_thread(tx, 32): + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 4096), 1, 8)] = W_1[ + T.ramp( ( ( ( - (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) - + (ty * 512) + ( + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) + ) + + (tz * 256) ) - + (tz * 256) + + (tx * 8) ) - + (tx * 8) - ) - + 131072 - ), - 1, - 8, - ), - T.broadcast(True, 8), - ] - with T.launch_thread(tx, 32): - W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 6144), 1, 8)] = W_1[ - T.ramp( - ( + + 131072 + ), + 1, + 8, + ) + ] + with T.launch_thread(tx, 32): + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 6144), 1, 8)] = W_1[ + T.ramp( ( ( ( - (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) - + (ty * 512) + ( + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) + ) + + (tz * 256) ) - + (tz * 256) + + (tx * 8) ) - + (tx * 8) - ) - + 139264 - ), - 1, - 8, - ), - T.broadcast(True, 8), - ] - with T.launch_thread(tx, 32): - W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 8192), 1, 8)] = W_1[ - T.ramp( - ( + + 139264 + ), + 1, + 8, + ) + ] + with T.launch_thread(tx, 32): + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 8192), 1, 8)] = W_1[ + T.ramp( ( ( ( - (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) - + (ty * 512) + ( + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) + ) + + (tz * 256) ) - + (tz * 256) + + (tx * 8) ) - + (tx * 8) - ) - + 262144 - ), - 1, - 8, - ), - T.broadcast(True, 8), - ] - with T.launch_thread(tx, 32): - W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 10240), 1, 8)] = W_1[ - T.ramp( - ( + + 262144 + ), + 1, + 8, + ) + ] + with T.launch_thread(tx, 32): + W_shared[T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 10240), 1, 8)] = W_1[ + T.ramp( ( ( ( - (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) - + (ty * 512) + ( + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) + ) + + (tz * 256) ) - + (tz * 256) + + (tx * 8) ) - + (tx * 8) - ) - + 270336 - ), - 1, - 8, - ), - T.broadcast(True, 8), - ] - for ic_inner in T.serial(0, 2): - for kw in T.serial(0, 3): - T.evaluate( - T.tvm_load_matrix_sync( - Apad_shared_wmma_matrix_a, - 16, - 16, - 16, - 0, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - Apad_shared, - (((ty * 3072) + (kw * 512)) + (ic_inner * 256)), - 256, - 1, - dtype="handle", + + 270336 ), - 16, - "row_major", - dtype="handle", - ) - ) - T.evaluate( - T.tvm_load_matrix_sync( - Apad_shared_wmma_matrix_a, - 16, - 16, - 16, 1, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - Apad_shared, - ((((ty * 3072) + (kw * 512)) + (ic_inner * 256)) + 1536), - 256, - 1, + 8, + ) + ] + for ic_inner in T.serial(0, 2): + for kw in T.serial(0, 3): + T.evaluate( + T.tvm_load_matrix_sync( + Apad_shared_wmma_matrix_a.data, + 16, + 16, + 16, + 0, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + Apad_shared.data, + (((ty * 3072) + (kw * 512)) + (ic_inner * 256)), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_load_matrix_sync( - W_shared_wmma_matrix_b, - 16, - 16, - 16, - 0, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - W_shared, - (((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)), - 256, + T.evaluate( + T.tvm_load_matrix_sync( + Apad_shared_wmma_matrix_a.data, + 16, + 16, + 16, 1, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + Apad_shared.data, + ((((ty * 3072) + (kw * 512)) + (ic_inner * 256)) + 1536), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_load_matrix_sync( - W_shared_wmma_matrix_b, - 16, - 16, - 16, - 1, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - W_shared, - ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 256), - 256, - 1, + T.evaluate( + T.tvm_load_matrix_sync( + W_shared_wmma_matrix_b.data, + 16, + 16, + 16, + 0, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + W_shared.data, + (((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_load_matrix_sync( - W_shared_wmma_matrix_b, - 16, - 16, - 16, - 2, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - W_shared, - ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 512), - 256, + T.evaluate( + T.tvm_load_matrix_sync( + W_shared_wmma_matrix_b.data, + 16, + 16, + 16, 1, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + W_shared.data, + ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 256), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_load_matrix_sync( - W_shared_wmma_matrix_b, - 16, - 16, - 16, - 3, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - W_shared, - ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 768), - 256, - 1, + T.evaluate( + T.tvm_load_matrix_sync( + W_shared_wmma_matrix_b.data, + 16, + 16, + 16, + 2, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + W_shared.data, + ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 512), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 0, - Apad_shared_wmma_matrix_a, - 0, - W_shared_wmma_matrix_b, - 0, - Conv_wmma_accumulator, - 0, - dtype="handle", + T.evaluate( + T.tvm_load_matrix_sync( + W_shared_wmma_matrix_b.data, + 16, + 16, + 16, + 3, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + W_shared.data, + ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 768), + 256, + 1, + dtype="handle", + ), + 16, + "row_major", + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 1, - Apad_shared_wmma_matrix_a, - 0, - W_shared_wmma_matrix_b, - 1, - Conv_wmma_accumulator, - 1, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 0, + Apad_shared_wmma_matrix_a.data, + 0, + W_shared_wmma_matrix_b.data, + 0, + Conv_wmma_accumulator.data, + 0, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 2, - Apad_shared_wmma_matrix_a, - 0, - W_shared_wmma_matrix_b, - 2, - Conv_wmma_accumulator, - 2, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 1, + Apad_shared_wmma_matrix_a.data, + 0, + W_shared_wmma_matrix_b.data, + 1, + Conv_wmma_accumulator.data, + 1, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 3, - Apad_shared_wmma_matrix_a, - 0, - W_shared_wmma_matrix_b, - 3, - Conv_wmma_accumulator, - 3, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 2, + Apad_shared_wmma_matrix_a.data, + 0, + W_shared_wmma_matrix_b.data, + 2, + Conv_wmma_accumulator.data, + 2, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 4, - Apad_shared_wmma_matrix_a, - 1, - W_shared_wmma_matrix_b, - 0, - Conv_wmma_accumulator, - 4, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 3, + Apad_shared_wmma_matrix_a.data, + 0, + W_shared_wmma_matrix_b.data, + 3, + Conv_wmma_accumulator.data, + 3, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 5, - Apad_shared_wmma_matrix_a, - 1, - W_shared_wmma_matrix_b, - 1, - Conv_wmma_accumulator, - 5, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 4, + Apad_shared_wmma_matrix_a.data, + 1, + W_shared_wmma_matrix_b.data, + 0, + Conv_wmma_accumulator.data, + 4, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 6, - Apad_shared_wmma_matrix_a, - 1, - W_shared_wmma_matrix_b, - 2, - Conv_wmma_accumulator, - 6, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 5, + Apad_shared_wmma_matrix_a.data, + 1, + W_shared_wmma_matrix_b.data, + 1, + Conv_wmma_accumulator.data, + 5, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_mma_sync( - Conv_wmma_accumulator, - 7, - Apad_shared_wmma_matrix_a, - 1, - W_shared_wmma_matrix_b, - 3, - Conv_wmma_accumulator, - 7, - dtype="handle", + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 6, + Apad_shared_wmma_matrix_a.data, + 1, + W_shared_wmma_matrix_b.data, + 2, + Conv_wmma_accumulator.data, + 6, + dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 0, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - (((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + (tz * 1024)), - 256, - 2, - dtype="handle", - ), - 16, - "row_major", - dtype="handle", - ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 1, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_mma_sync( + Conv_wmma_accumulator.data, + 7, + Apad_shared_wmma_matrix_a.data, + 1, + W_shared_wmma_matrix_b.data, + 3, + Conv_wmma_accumulator.data, + 7, + dtype="handle", + ) + ) + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 0, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + (tz * 1024) - ) - + 256 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 2, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 1, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 512 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 256 + ), + 256, + 2, + dtype="handle", ), - 256, + 16, + "row_major", + dtype="handle", + ) + ) + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, 2, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, + ( + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 512 + ), + 256, + 2, + dtype="handle", + ), + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 3, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 3, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 768 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 768 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 4, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 4, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 1605632 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 1605632 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 5, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 5, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 1605888 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 1605888 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 6, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 6, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 1606144 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 1606144 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) ) - ) - T.evaluate( - T.tvm_store_matrix_sync( - Conv_wmma_accumulator, - 16, - 16, - 16, - 7, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - Conv_1.data, - ( + T.evaluate( + T.tvm_store_matrix_sync( + Conv_wmma_accumulator.data, + 16, + 16, + 16, + 7, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + Conv_1.data, ( - ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) - + (tz * 1024) - ) - + 1606400 + ( + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) + ) + + 1606400 + ), + 256, + 2, + dtype="handle", ), - 256, - 2, + 16, + "row_major", dtype="handle", - ), - 16, - "row_major", - dtype="handle", + ) + ) + + return func + + +def opt_conv_tensorcore_mod_host(): + @T.prim_func + def opt_conv_tensorcore_mod_host( + args: T.handle, + arg_type_ids: T.Buffer[(3,), "int32"], + num_args: T.int32, + out_ret_value: T.handle, + out_ret_tcode: T.handle, + resource_handle: T.handle, + ) -> T.int32: + # function attr dict + T.func_attr( + { + "tir.noalias": True, + "global_symbol": "default_function", + "tir.is_entry_func": True, + "calling_conv": 1, + } ) - ) - - -def test_opt_conv_tensorcore_lower(): - mod = opt_conv_tensorcore_lower - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) - - -@T.prim_func -def opt_conv_tensorcore_mod_host( - args: T.handle, - arg_type_ids: T.handle, - num_args: T.int32, - out_ret_value: T.handle, - out_ret_tcode: T.handle, - resource_handle: T.handle, -) -> T.int32: - # function attr dict - T.func_attr( - { - "tir.noalias": True, - "global_symbol": "default_function", - "tir.is_entry_func": True, - "calling_conv": 1, - } - ) - # body - stack_tcode: T.handle = T.tvm_stack_alloca("arg_tcode", 10, dtype="handle") - stack_value: T.handle = T.tvm_stack_alloca("arg_value", 10, dtype="handle") - assert num_args == 3, "default_function: num_args should be 3" - arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") - arg0_code: T.int32 = arg_type_ids[0] - arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") - arg1_code: T.int32 = arg_type_ids[1] - arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") - arg2_code: T.int32 = arg_type_ids[2] - A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") - T.attr(A, "storage_alignment", 128) - arg0_shape: T.handle = T.tvm_struct_get(arg0, 0, 2, dtype="handle") - arg0_strides: T.handle = T.tvm_struct_get(arg0, 0, 3, dtype="handle") - dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") - W: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle") - T.attr(W, "storage_alignment", 128) - arg1_shape: T.handle = T.tvm_struct_get(arg1, 0, 2, dtype="handle") - arg1_strides: T.handle = T.tvm_struct_get(arg1, 0, 3, dtype="handle") - Conv: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle") - T.attr(Conv, "storage_alignment", 128) - arg2_shape: T.handle = T.tvm_struct_get(arg2, 0, 2, dtype="handle") - arg2_strides: T.handle = T.tvm_struct_get(arg2, 0, 3, dtype="handle") - assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( - arg0_code == 4 - ), "default_function: Expect arg[0] to be pointer" - assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or ( - arg1_code == 4 - ), "default_function: Expect arg[1] to be pointer" - assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( - arg2_code == 4 - ), "default_function: Expect arg[2] to be pointer" - assert 6 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" - assert 6 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" - assert ( - (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(16)) - ) and ( - T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg0.dtype is expected to be float16" - assert 16 == T.cast( - arg0_shape[0], "int32" - ), "Argument arg0.shape[0] has an unsatisfied constraint" - assert 14 == T.cast( - arg0_shape[1], "int32" - ), "Argument arg0.shape[1] has an unsatisfied constraint" - assert 14 == T.cast( - arg0_shape[2], "int32" - ), "Argument arg0.shape[2] has an unsatisfied constraint" - assert 16 == T.cast( - arg0_shape[3], "int32" - ), "Argument arg0.shape[3] has an unsatisfied constraint" - assert 16 == T.cast( - arg0_shape[4], "int32" - ), "Argument arg0.shape[4] has an unsatisfied constraint" - assert 16 == T.cast( - arg0_shape[5], "int32" - ), "Argument arg0.shape[5] has an unsatisfied constraint" - if not (T.isnullptr(arg0_strides, dtype="bool")): + # body + stack_tcode_data: T.Ptr[T.int32] = T.tvm_stack_alloca("arg_tcode", 10, dtype="handle") + stack_tcode = T.buffer_decl([9], "int32", data=stack_tcode_data) + stack_value: T.handle = T.tvm_stack_alloca("arg_value", 10, dtype="handle") + assert num_args == 3, "default_function: num_args should be 3" + arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") + arg0_code: T.int32 = arg_type_ids[0] + arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") + arg1_code: T.int32 = arg_type_ids[1] + arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") + arg2_code: T.int32 = arg_type_ids[2] + + A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + T.attr(A, "storage_alignment", 128) + arg0_shape_data: T.Ptr[T.int64] = T.tvm_struct_get(arg0, 0, 2, dtype="handle") + arg0_shape = T.buffer_decl([6], "int64", data=arg0_shape_data) + arg0_strides_data: T.Ptr[T.int64] = T.tvm_struct_get(arg0, 0, 3, dtype="handle") + arg0_strides = T.buffer_decl([6], "int64", data=arg0_strides_data) + + dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") + + W: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + T.attr(W, "storage_alignment", 128) + arg1_shape_data: T.Ptr[T.int64] = T.tvm_struct_get(arg1, 0, 2, dtype="handle") + arg1_shape = T.buffer_decl([6], "int64", data=arg1_shape_data) + arg1_strides_data: T.Ptr[T.int64] = T.tvm_struct_get(arg1, 0, 3, dtype="handle") + arg1_strides = T.buffer_decl([6], "int64", data=arg1_strides_data) + + Conv: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle") + T.attr(Conv, "storage_alignment", 128) + arg2_shape_data: T.Ptr[T.int64] = T.tvm_struct_get(arg2, 0, 2, dtype="handle") + arg2_shape = T.buffer_decl([6], "int64", data=arg2_shape_data) + arg2_strides_data: T.Ptr[T.int64] = T.tvm_struct_get(arg2, 0, 3, dtype="handle") + arg2_strides = T.buffer_decl([6], "int64", data=arg2_strides_data) + + assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( + arg0_code == 4 + ), "default_function: Expect arg[0] to be pointer" + assert (((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or ( + arg1_code == 4 + ), "default_function: Expect arg[1] to be pointer" + assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( + arg2_code == 4 + ), "default_function: Expect arg[2] to be pointer" + assert 6 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" assert ( - ( + (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(16)) + ) and ( + T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg0.dtype is expected to be float16" + assert 16 == T.cast( + arg0_shape[0], "int32" + ), "Argument arg0.shape[0] has an unsatisfied constraint" + assert 14 == T.cast( + arg0_shape[1], "int32" + ), "Argument arg0.shape[1] has an unsatisfied constraint" + assert 14 == T.cast( + arg0_shape[2], "int32" + ), "Argument arg0.shape[2] has an unsatisfied constraint" + assert 16 == T.cast( + arg0_shape[3], "int32" + ), "Argument arg0.shape[3] has an unsatisfied constraint" + assert 16 == T.cast( + arg0_shape[4], "int32" + ), "Argument arg0.shape[4] has an unsatisfied constraint" + assert 16 == T.cast( + arg0_shape[5], "int32" + ), "Argument arg0.shape[5] has an unsatisfied constraint" + if not (T.isnullptr(arg0_strides.data, dtype="bool")): + assert ( ( ( - (1 == T.cast(arg0_strides[5], "int32")) - and (16 == T.cast(arg0_strides[4], "int32")) + ( + (1 == T.cast(arg0_strides[5], "int32")) + and (16 == T.cast(arg0_strides[4], "int32")) + ) + and (256 == T.cast(arg0_strides[3], "int32")) ) - and (256 == T.cast(arg0_strides[3], "int32")) + and (4096 == T.cast(arg0_strides[2], "int32")) ) - and (4096 == T.cast(arg0_strides[2], "int32")) - ) - and (57344 == T.cast(arg0_strides[1], "int32")) - ) and ( - 802816 == T.cast(arg0_strides[0], "int32") - ), "arg0.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg0, 0, 8, dtype="uint64" - ), "Argument arg0.byte_offset has an unsatisfied constraint" - assert 2 == T.tvm_struct_get( - arg0, 0, 10, dtype="int32" - ), "Argument arg0.device_type has an unsatisfied constraint" - assert 6 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" - assert 6 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" - assert ( - (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(16)) - ) and ( - T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg1.dtype is expected to be float16" - assert 3 == T.cast( - arg1_shape[0], "int32" - ), "Argument arg1.shape[0] has an unsatisfied constraint" - assert 3 == T.cast( - arg1_shape[1], "int32" - ), "Argument arg1.shape[1] has an unsatisfied constraint" - assert 16 == T.cast( - arg1_shape[2], "int32" - ), "Argument arg1.shape[2] has an unsatisfied constraint" - assert 32 == T.cast( - arg1_shape[3], "int32" - ), "Argument arg1.shape[3] has an unsatisfied constraint" - assert 16 == T.cast( - arg1_shape[4], "int32" - ), "Argument arg1.shape[4] has an unsatisfied constraint" - assert 16 == T.cast( - arg1_shape[5], "int32" - ), "Argument arg1.shape[5] has an unsatisfied constraint" - if not (T.isnullptr(arg1_strides, dtype="bool")): + and (57344 == T.cast(arg0_strides[1], "int32")) + ) and ( + 802816 == T.cast(arg0_strides[0], "int32") + ), "arg0.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg0, 0, 8, dtype="uint64" + ), "Argument arg0.byte_offset has an unsatisfied constraint" + assert 2 == T.tvm_struct_get( + arg0, 0, 10, dtype="int32" + ), "Argument arg0.device_type has an unsatisfied constraint" + assert 6 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" assert ( - ( + (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(16)) + ) and ( + T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg1.dtype is expected to be float16" + assert 3 == T.cast( + arg1_shape[0], "int32" + ), "Argument arg1.shape[0] has an unsatisfied constraint" + assert 3 == T.cast( + arg1_shape[1], "int32" + ), "Argument arg1.shape[1] has an unsatisfied constraint" + assert 16 == T.cast( + arg1_shape[2], "int32" + ), "Argument arg1.shape[2] has an unsatisfied constraint" + assert 32 == T.cast( + arg1_shape[3], "int32" + ), "Argument arg1.shape[3] has an unsatisfied constraint" + assert 16 == T.cast( + arg1_shape[4], "int32" + ), "Argument arg1.shape[4] has an unsatisfied constraint" + assert 16 == T.cast( + arg1_shape[5], "int32" + ), "Argument arg1.shape[5] has an unsatisfied constraint" + if not (T.isnullptr(arg1_strides.data, dtype="bool")): + assert ( ( ( - (1 == T.cast(arg1_strides[5], "int32")) - and (16 == T.cast(arg1_strides[4], "int32")) + ( + (1 == T.cast(arg1_strides[5], "int32")) + and (16 == T.cast(arg1_strides[4], "int32")) + ) + and (256 == T.cast(arg1_strides[3], "int32")) ) - and (256 == T.cast(arg1_strides[3], "int32")) + and (8192 == T.cast(arg1_strides[2], "int32")) ) - and (8192 == T.cast(arg1_strides[2], "int32")) - ) - and (131072 == T.cast(arg1_strides[1], "int32")) - ) and ( - 393216 == T.cast(arg1_strides[0], "int32") - ), "arg1.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg1, 0, 8, dtype="uint64" - ), "Argument arg1.byte_offset has an unsatisfied constraint" - assert 2 == T.tvm_struct_get( - arg1, 0, 10, dtype="int32" - ), "Argument arg1.device_type has an unsatisfied constraint" - assert dev_id == T.tvm_struct_get( - arg1, 0, 9, dtype="int32" - ), "Argument arg1.device_id has an unsatisfied constraint" - assert 6 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" - assert 6 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" - assert ( - (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) - and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") == T.uint8(32)) - ) and ( - T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) - ), "arg2.dtype is expected to be float32" - assert 16 == T.cast( - arg2_shape[0], "int32" - ), "Argument arg2.shape[0] has an unsatisfied constraint" - assert 14 == T.cast( - arg2_shape[1], "int32" - ), "Argument arg2.shape[1] has an unsatisfied constraint" - assert 14 == T.cast( - arg2_shape[2], "int32" - ), "Argument arg2.shape[2] has an unsatisfied constraint" - assert 32 == T.cast( - arg2_shape[3], "int32" - ), "Argument arg2.shape[3] has an unsatisfied constraint" - assert 16 == T.cast( - arg2_shape[4], "int32" - ), "Argument arg2.shape[4] has an unsatisfied constraint" - assert 16 == T.cast( - arg2_shape[5], "int32" - ), "Argument arg2.shape[5] has an unsatisfied constraint" - if not (T.isnullptr(arg2_strides, dtype="bool")): + and (131072 == T.cast(arg1_strides[1], "int32")) + ) and ( + 393216 == T.cast(arg1_strides[0], "int32") + ), "arg1.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg1, 0, 8, dtype="uint64" + ), "Argument arg1.byte_offset has an unsatisfied constraint" + assert 2 == T.tvm_struct_get( + arg1, 0, 10, dtype="int32" + ), "Argument arg1.device_type has an unsatisfied constraint" + assert dev_id == T.tvm_struct_get( + arg1, 0, 9, dtype="int32" + ), "Argument arg1.device_id has an unsatisfied constraint" + assert 6 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" assert ( - ( + (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") == T.uint8(32)) + ) and ( + T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) + ), "arg2.dtype is expected to be float32" + assert 16 == T.cast( + arg2_shape[0], "int32" + ), "Argument arg2.shape[0] has an unsatisfied constraint" + assert 14 == T.cast( + arg2_shape[1], "int32" + ), "Argument arg2.shape[1] has an unsatisfied constraint" + assert 14 == T.cast( + arg2_shape[2], "int32" + ), "Argument arg2.shape[2] has an unsatisfied constraint" + assert 32 == T.cast( + arg2_shape[3], "int32" + ), "Argument arg2.shape[3] has an unsatisfied constraint" + assert 16 == T.cast( + arg2_shape[4], "int32" + ), "Argument arg2.shape[4] has an unsatisfied constraint" + assert 16 == T.cast( + arg2_shape[5], "int32" + ), "Argument arg2.shape[5] has an unsatisfied constraint" + if not (T.isnullptr(arg2_strides.data, dtype="bool")): + assert ( ( ( - (1 == T.cast(arg2_strides[5], "int32")) - and (16 == T.cast(arg2_strides[4], "int32")) + ( + (1 == T.cast(arg2_strides[5], "int32")) + and (16 == T.cast(arg2_strides[4], "int32")) + ) + and (256 == T.cast(arg2_strides[3], "int32")) ) - and (256 == T.cast(arg2_strides[3], "int32")) + and (8192 == T.cast(arg2_strides[2], "int32")) ) - and (8192 == T.cast(arg2_strides[2], "int32")) + and (114688 == T.cast(arg2_strides[1], "int32")) + ) and ( + 1605632 == T.cast(arg2_strides[0], "int32") + ), "arg2.strides: expected to be compact array" + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( + arg2, 0, 8, dtype="uint64" + ), "Argument arg2.byte_offset has an unsatisfied constraint" + assert 2 == T.tvm_struct_get( + arg2, 0, 10, dtype="int32" + ), "Argument arg2.device_type has an unsatisfied constraint" + assert dev_id == T.tvm_struct_get( + arg2, 0, 9, dtype="int32" + ), "Argument arg2.device_id has an unsatisfied constraint" + T.evaluate(T.tvm_struct_set(stack_value, 0, 12, T.cast(2, "int64"), dtype="int32")) + stack_tcode[0] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 1, 12, T.cast(dev_id, "int64"), dtype="int32")) + stack_tcode[1] = 0 + T.evaluate( + T.tvm_call_packed_lowered( + "__tvm_set_device", stack_value, stack_tcode.data, 0, 2, dtype="int32" + ) + ) + T.attr(0, "compute_scope", "default_function_compute_") + T.evaluate(T.tvm_struct_set(stack_value, 0, 12, A, dtype="int32")) + stack_tcode[0] = 3 + T.evaluate(T.tvm_struct_set(stack_value, 1, 12, W, dtype="int32")) + stack_tcode[1] = 3 + T.evaluate(T.tvm_struct_set(stack_value, 2, 12, Conv, dtype="int32")) + stack_tcode[2] = 3 + T.evaluate(T.tvm_struct_set(stack_value, 3, 12, T.cast(196, "int64"), dtype="int32")) + stack_tcode[3] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 4, 12, T.cast(2, "int64"), dtype="int32")) + stack_tcode[4] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 5, 12, T.cast(4, "int64"), dtype="int32")) + stack_tcode[5] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 6, 12, T.cast(4, "int64"), dtype="int32")) + stack_tcode[6] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 7, 12, T.cast(2, "int64"), dtype="int32")) + stack_tcode[7] = 0 + T.evaluate(T.tvm_struct_set(stack_value, 8, 12, T.cast(32, "int64"), dtype="int32")) + stack_tcode[8] = 0 + T.evaluate( + T.tvm_call_packed_lowered( + "default_function_kernel0", stack_value, stack_tcode.data, 0, 9, dtype="int32" ) - and (114688 == T.cast(arg2_strides[1], "int32")) - ) and ( - 1605632 == T.cast(arg2_strides[0], "int32") - ), "arg2.strides: expected to be compact array" - T.evaluate(0) - assert T.uint64(0) == T.tvm_struct_get( - arg2, 0, 8, dtype="uint64" - ), "Argument arg2.byte_offset has an unsatisfied constraint" - assert 2 == T.tvm_struct_get( - arg2, 0, 10, dtype="int32" - ), "Argument arg2.device_type has an unsatisfied constraint" - assert dev_id == T.tvm_struct_get( - arg2, 0, 9, dtype="int32" - ), "Argument arg2.device_id has an unsatisfied constraint" - T.evaluate(T.tvm_struct_set(stack_value, 0, 12, T.cast(2, "int64"), dtype="int32")) - stack_tcode[0] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 1, 12, T.cast(dev_id, "int64"), dtype="int32")) - stack_tcode[1] = 0 - T.evaluate( - T.tvm_call_packed_lowered("__tvm_set_device", stack_value, stack_tcode, 0, 2, dtype="int32") - ) - T.attr(0, "compute_scope", "default_function_compute_") - T.evaluate(T.tvm_struct_set(stack_value, 0, 12, A, dtype="int32")) - stack_tcode[0] = 3 - T.evaluate(T.tvm_struct_set(stack_value, 1, 12, W, dtype="int32")) - stack_tcode[1] = 3 - T.evaluate(T.tvm_struct_set(stack_value, 2, 12, Conv, dtype="int32")) - stack_tcode[2] = 3 - T.evaluate(T.tvm_struct_set(stack_value, 3, 12, T.cast(196, "int64"), dtype="int32")) - stack_tcode[3] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 4, 12, T.cast(2, "int64"), dtype="int32")) - stack_tcode[4] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 5, 12, T.cast(4, "int64"), dtype="int32")) - stack_tcode[5] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 6, 12, T.cast(4, "int64"), dtype="int32")) - stack_tcode[6] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 7, 12, T.cast(2, "int64"), dtype="int32")) - stack_tcode[7] = 0 - T.evaluate(T.tvm_struct_set(stack_value, 8, 12, T.cast(32, "int64"), dtype="int32")) - stack_tcode[8] = 0 - T.evaluate( - T.tvm_call_packed_lowered( - "default_function_kernel0", stack_value, stack_tcode, 0, 9, dtype="int32" ) - ) - -def test_opt_conv_tensorcore_mod_host(): - mod = opt_conv_tensorcore_mod_host - rt_mod = tvm.script.from_source(mod.script(show_meta=True)) - tvm.ir.assert_structural_equal(mod, rt_mod, True) + return opt_conv_tensorcore_mod_host -@T.prim_func -def vthread_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") +def vthread_func(): + @T.prim_func + def vthread_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") + + T.launch_thread(i0, 4) + T.launch_thread(i1, 2) + T.launch_thread(i2, 2) + B = T.allocate([16], "float32", "local") + for j in range(16): + B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + T.float32(1) + for j in range(16): + C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * T.float32(2) - i0 = T.env_thread("blockIdx.x") - i1 = T.env_thread("threadIdx.x") - i2 = T.env_thread("vthread") + return vthread_func - T.launch_thread(i0, 4) - T.launch_thread(i1, 2) - T.launch_thread(i2, 2) - B = T.allocate([16], "float32", "local") - for j in range(16): - B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + T.float32(1) - for j in range(16): - C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * T.float32(2) +def matmul(): + @T.prim_func + def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) -def test_vthread(): - func = vthread_func - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + return matmul -@T.prim_func -def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) - for i, j, k in T.grid(128, 128, 128): - with T.block("update"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): +def matmul_original(): + @T.prim_func + def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for i, j in T.grid(128, 128): + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for k in range(128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@T.prim_func -def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) + return matmul_original - for i, j in T.grid(128, 128): - with T.block("init"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = T.float32(0) - for k in range(128): - with T.block("update"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] +def element_wise(): + @T.prim_func + def element_wise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) -@T.prim_func -def element_wise(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (128, 128), "float32") - C = T.match_buffer(c, (128, 128), "float32") - B = T.alloc_buffer((128, 128), "float32") + return element_wise - for i, j in T.grid(128, 128): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] * T.float32(2) - for i, j in T.grid(128, 128): - with T.block("C"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = B[vi, vj] + T.float32(1) +def predicate(): + @T.prim_func + def predicate(b: T.handle, c: T.handle) -> None: + B = T.match_buffer(b, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") -@T.prim_func -def predicate(b: T.handle, c: T.handle) -> None: - B = T.match_buffer(b, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") + for i, jo, ji in T.grid(16, 4, 5): + with T.block("update"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, jo * 4 + ji) + T.where(jo * 4 + ji < 16) + C[vi, vj] = B[vi, vj] + T.float32(1) - for i, jo, ji in T.grid(16, 4, 5): - with T.block("update"): - vi = T.axis.S(16, i) - vj = T.axis.S(16, jo * 4 + ji) - T.where(jo * 4 + ji < 16) - C[vi, vj] = B[vi, vj] + T.float32(1) + return predicate def test_module_define(): - func1 = tvm.ir.IRModule({"matmul": matmul})["matmul"] - func2 = tvm.ir.IRModule({"element_wise": element_wise})["element_wise"] - func3 = tvm.ir.IRModule({"predicate": predicate})["predicate"] + func1 = tvm.ir.IRModule({"matmul": matmul()})["matmul"] + func2 = tvm.ir.IRModule({"element_wise": element_wise()})["element_wise"] + func3 = tvm.ir.IRModule({"predicate": predicate()})["predicate"] mod1 = tvm.ir.IRModule({"func1": func1, "func2": func2, "func3": func3}) - mod2 = tvm.ir.IRModule({"func1": matmul, "func2": element_wise, "func3": predicate}) + mod2 = tvm.ir.IRModule({"func1": matmul(), "func2": element_wise(), "func3": predicate()}) tvm.ir.assert_structural_equal(mod1, mod2) -def test_matmul(): - func = matmul - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) - - def test_matmul_original(): - func = matmul_original + func = matmul_original() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2511,7 +2586,7 @@ def test_matmul_original(): def test_element_wise(): - func = element_wise + func = element_wise() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2527,7 +2602,7 @@ def test_element_wise(): def test_predicate(): - func = predicate + func = predicate() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2538,20 +2613,23 @@ def test_predicate(): assert isinstance(rt_func.body.block.body.body.body.body.block, tir.stmt.Block) -@T.prim_func -def for_thread_binding(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - B = T.match_buffer(b, (16, 16), "float32") +def for_thread_binding(): + @T.prim_func + def for_thread_binding(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (16, 16), "float32") + + for i in T.thread_binding(0, 16, thread="threadIdx.x"): + for j in T.thread_binding( + 0, 16, thread="threadIdx.y", annotations={"attr_key": "attr_value"} + ): + A[i, j] = B[i, j] + T.float32(1) - for i in T.thread_binding(0, 16, thread="threadIdx.x"): - for j in T.thread_binding( - 0, 16, thread="threadIdx.y", annotations={"attr_key": "attr_value"} - ): - A[i, j] = B[i, j] + T.float32(1) + return for_thread_binding def test_for_thread_binding(): - func = for_thread_binding + func = for_thread_binding() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2564,25 +2642,28 @@ def test_for_thread_binding(): assert rt_func.body.body.annotations["attr_key"] == "attr_value" -@T.prim_func -def match_buffer_region(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (16, 16, 16), "float32") - B = T.match_buffer(b, (1), "float32") +def match_buffer_region(): + @T.prim_func + def match_buffer_region(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16, 16), "float32") + B = T.match_buffer(b, (1), "float32") - for i, j in T.grid(16, 4): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) - for ii in range(4): - with T.block(): - vii = T.axis.S(4, ii) - D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) - for i, j in T.grid(4, 4): - B[0] += D[i, 0, j] + for i, j in T.grid(16, 4): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) + for ii in range(4): + with T.block(): + vii = T.axis.S(4, ii) + D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) + for i, j in T.grid(4, 4): + B[0] += D[i, 0, j] + + return match_buffer_region def test_match_buffer_region(): - func = match_buffer_region + func = match_buffer_region() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2605,26 +2686,29 @@ def test_match_buffer_region(): tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) -@T.prim_func -def block_elements(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - B = T.match_buffer(b, (1, 1), "float32") +def block_elements(): + @T.prim_func + def block_elements(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (1, 1), "float32") + + with T.block("update"): + vi = T.axis.S(1, 0) + T.where(True) + T.reads(A[0:16, 0:16]) + T.writes(B[0, 0]) + T.block_attr({"attr_key": "attr_value"}) + C = T.alloc_buffer((4, 4), dtype="float32") + D = T.match_buffer(A[0:4, 0], (4, 1)) + with T.init(): + B[0, 0] = T.float32(0) + B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2] - with T.block("update"): - vi = T.axis.S(1, 0) - T.where(True) - T.reads(A[0:16, 0:16]) - T.writes(B[0, 0]) - T.block_attr({"attr_key": "attr_value"}) - C = T.alloc_buffer((4, 4), dtype="float32") - D = T.match_buffer(A[0:4, 0], (4, 1)) - with T.init(): - B[0, 0] = T.float32(0) - B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2] + return block_elements def test_block_elements(): - func = block_elements + func = block_elements() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2638,26 +2722,29 @@ def test_block_elements(): assert block.annotations["attr_key"] == "attr_value" -@T.prim_func -def opaque_block(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - B = T.match_buffer(b, (16, 16), "float32") +def opaque_block(): + @T.prim_func + def opaque_block(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (16, 16), "float32") - for i in range(16): - for j in range(16): - with T.block(): - T.reads([]) - T.writes(A[i, j]) - A[i, j] = T.float32(0) - with T.block(): - T.reads([A[i, 0:16]]) - T.writes([B[i, 0:16]]) + for i in range(16): for j in range(16): - B[i, j] = A[i, j] + with T.block(): + T.reads([]) + T.writes(A[i, j]) + A[i, j] = T.float32(0) + with T.block(): + T.reads([A[i, 0:16]]) + T.writes([B[i, 0:16]]) + for j in range(16): + B[i, j] = A[i, j] + + return opaque_block def test_opaque_block(): - func = opaque_block + func = opaque_block() rt_func = tvm.script.from_source(func.script(show_meta=True)) tvm.ir.assert_structural_equal(func, rt_func) @@ -2673,124 +2760,106 @@ def test_opaque_block(): assert len(root_block.body.body[1].block.iter_vars) == 0 -@T.prim_func -def rank0(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - B = T.alloc_buffer((), "float32") - A[()] = 2 - B[()] = A[()] - - -def test_rank0_buffers(): - func = rank0 - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) - - -@T.prim_func -def rank0_block(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - B = T.alloc_buffer((), "float32") - B[0] = A.data[0] - - with T.block("update") as []: - T.reads([A[()]]) - T.writes([B[()]]) - for i in range(1): - B[()] = A[()] - +def rank0(): + @T.prim_func + def rank0(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + B = T.alloc_buffer((), "float32") + A[()] = 2 + B[()] = A[()] -def test_rank0_blocks(): - func = rank0_block - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) + return rank0 -@T.prim_func -def select(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - A[()] = T.Select(True, 1, 2) +def rank0_block(): + @T.prim_func + def rank0_block(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + B = T.alloc_buffer((), "float32") + B[0] = A[0] + with T.block("update") as []: + T.reads([A[()]]) + T.writes([B[()]]) + for i in range(1): + B[()] = A[()] -def test_select(): - func = select - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) + return rank0_block -@T.prim_func -def minmax(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - A[()] = T.min(1, 2) - A[()] = T.max(1, 2) +def select(): + @T.prim_func + def select(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + A[()] = T.Select(True, 1, 2) + return select -def test_minmax(): - func = minmax - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) +def minmax(): + @T.prim_func + def minmax(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + A[()] = T.min(1, 2) + A[()] = T.max(1, 2) -@T.prim_func -def abs(a: T.handle) -> None: - A = T.match_buffer(a, (128, 128), "float32") + return minmax - for i, j in T.grid(128, 128): - with T.block("A"): - vi, vj = T.axis.remap("SS", [i, j]) - A[vi, vj] = T.abs(A[vi, vj]) +def abs(): + @T.prim_func + def abs(a: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") -def test_abs(): - func = abs - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) + for i, j in T.grid(128, 128): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = T.abs(A[vi, vj]) + return abs -@T.prim_func -def constant_folding(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - A[()] = T.min(2.2, 5.2) - A[()] = T.max(T.float32(2.2), T.float32(T.float32(5.2))) - A[()] = T.min(2.2, 5.0) +def constant_folding(): + @T.prim_func + def constant_folding(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + A[()] = T.min(2.2, 5.2) + A[()] = T.max(T.float32(2.2), T.float32(T.float32(5.2))) + A[()] = T.min(2.2, 5.0) -def test_constant_folding(): - func = constant_folding - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) + return constant_folding -@T.prim_func -def simplify_bracket() -> None: - a = T.var("int32") - b = T.var("int32") - c = T.var("int32") - d = T.var("int32") - T.evaluate(a + b * (c + d)) +def simplify_bracket(): + @T.prim_func + def simplify_bracket() -> None: + a = T.var("int32") + b = T.var("int32") + c = T.var("int32") + d = T.var("int32") + T.evaluate(a + b * (c + d)) + return simplify_bracket -def test_simplify_bracket(): - func = simplify_bracket - out_str = func.script(show_meta=True) - assert out_str.count("a + b * (c + d)") == 1 +def var_with_same_name(): + @T.prim_func + def var_with_same_name(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = 0 + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = 0 -@T.prim_func -def var_with_same_name(a: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - A[vi, vj] = 0 - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - A[vi, vj] = 0 + return var_with_same_name def test_same_name_var(): - func = var_with_same_name + func = var_with_same_name() out_str = func.script(tir_prefix="T", show_meta=True) rt_func = tvm.script.from_source(out_str) tvm.ir.assert_structural_equal(func, rt_func) @@ -2804,124 +2873,115 @@ def test_same_name_var(): assert out_str.find("i_") == -1 -@T.prim_func -def while_loop(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (16,), "float32") - B = T.match_buffer(b, (16,), "float32") - i = T.alloc_buffer((), "int32", scope="local") - for ii in range(16): - with T.block(): - vi = T.axis.S(16, ii) - B[vi] = 0 - while i[()] < 10: - for j in range(16): - B[j] += A[j] - +def while_loop(): + @T.prim_func + def while_loop(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + i = T.alloc_buffer((), "int32", scope="local") + for ii in range(16): + with T.block(): + vi = T.axis.S(16, ii) + B[vi] = 0 + while i[()] < 10: + for j in range(16): + B[j] += A[j] -def test_while_loop(): - rt_func = tvm.script.from_source(while_loop.script(show_meta=True)) - tvm.ir.assert_structural_equal(while_loop, rt_func) + return while_loop # fmt: off -@T.prim_func -def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - # body - tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"}) - for ax0_ax1_fused_4 in T.serial(0, 56): - for ax2_4 in T.serial(0, 56): - for ax3_init in T.serial(0, 64): - tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) - for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): - tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) - for ax0_ax1_fused_5 in T.serial(0, 56): - for ax2_5, ax3_3 in T.grid(56, 64): - T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") +def primfunc_with_allocate_annotations(): + @T.prim_func + def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"}) + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") + + return primfunc_with_allocate_annotations # fmt: on -def test_primfunc_with_allocate_annotations(): - func = primfunc_with_allocate_annotations - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) # fmt: off -@T.prim_func -def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - threadIdx_x = T.env_thread("threadIdx.x") - A = T.match_buffer(a, [128, 128], dtype="float32") - for i in T.serial(0, 128): - T.launch_thread(threadIdx_x, 128) - reduce_temp0 = T.allocate([1], "float32", "local") - with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): - T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A[i * 128 + threadIdx_x], True, reduce_temp0, threadIdx_x, dtype="handle")) - - -@T.prim_func -def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - threadIdx_x = T.env_thread("threadIdx.x") - A = T.match_buffer(a, [128, 128], dtype="float32") - for i in T.serial(0, 128): - T.launch_thread(threadIdx_x, 128) - reduce_temp0 = T.allocate([1], "float32", "local") - with T.attr(T.comm_reducer(lambda x0, x1, y0, y1: (T.Select((x1 >= y1), x0, y0), T.Select((x1 >= y1), x1, y1)), [T.int32(-1), T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): - T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A[i * 128 + threadIdx_x], True, reduce_temp0, threadIdx_x, dtype="handle")) - - -@T.prim_func -def multiple_commreducer() -> None: - normal_reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") - normal_reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") - reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") - reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") - for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_maxelem_cross_thread_reduction"): - T.attr(T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) - T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0[0], True, reduce_temp0.data, ax0_1, dtype="handle")) - for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_expsum_cross_thread_reduction"): - T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) - T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp1[0], True, reduce_temp1.data, ax0_1, dtype="handle")) -# fmt: on +def comm_reducer_single_reduce_group(): + @T.prim_func + def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + threadIdx_x = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128], dtype="float32") + for i in T.serial(0, 128): + T.launch_thread(threadIdx_x, 128) + reduce_temp0 = T.allocate([1], "float32", "local") + with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A[i * 128 + threadIdx_x], True, reduce_temp0.data, threadIdx_x, dtype="handle")) + return comm_reducer_single_reduce_group -def test_primfunc_with_single_reduce_group_commreducer(): - func = comm_reducer_single_reduce_group - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) +def comm_reducer_multiple_reduce_groups(): + @T.prim_func + def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + threadIdx_x = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128], dtype="float32") + for i in T.serial(0, 128): + T.launch_thread(threadIdx_x, 128) + reduce_temp0 = T.allocate([1], "float32", "local") + with T.attr(T.comm_reducer(lambda x0, x1, y0, y1: (T.Select((x1 >= y1), x0, y0), T.Select((x1 >= y1), x1, y1)), [T.int32(-1), T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), A[i * 128 + threadIdx_x], True, reduce_temp0.data, threadIdx_x, dtype="handle")) -def test_primfunc_with_multiple_reduce_group_commreducer(): - func = comm_reducer_multiple_reduce_groups - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) + return comm_reducer_multiple_reduce_groups -def test_primfunc_with_multiple_commreducer(): - func = multiple_commreducer - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) +def multiple_commreducer(): + @T.prim_func + def multiple_commreducer() -> None: + normal_reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_maxelem_cross_thread_reduction"): + T.attr(T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0[0], True, reduce_temp0.data, ax0_1, dtype="handle")) + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_expsum_cross_thread_reduction"): + T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp1[0], True, reduce_temp1.data, ax0_1, dtype="handle")) + + return multiple_commreducer +# fmt: on -@T.prim_func def func_div_mod(): - a = T.var("int32") - b = T.var("int32") - T.evaluate(a // b) - T.evaluate(a % b) - T.evaluate(a / b) - T.evaluate(T.truncmod(a, b)) + @T.prim_func + def func_div_mod(): + a = T.var("int32") + b = T.var("int32") + T.evaluate(a // b) + T.evaluate(a % b) + T.evaluate(a / b) + T.evaluate(T.truncmod(a, b)) + + return func_div_mod def test_div_mod(): - func = func_div_mod + func = func_div_mod() rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func, True) @@ -2931,128 +2991,144 @@ def test_div_mod(): assert isinstance(func.body[3].value, tvm.tir.Mod) -@T.prim_func -def loop_extent_dependent(a: T.handle) -> None: - A = T.match_buffer(a, [], dtype="int32") - for i in T.serial(0, 128): - for j in T.serial(0, i): - A[()] = A[()] + j - - -def test_loop_extent_dependent(): - func = loop_extent_dependent - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) - - -@T.prim_func -def nontrivial_range_axis(a: T.handle) -> None: - A = T.match_buffer(a, (10), "float32") - for i in range(10): - with T.block("block"): - vi = T.axis.spatial((1, 11), i + 1) - A[vi - 1] = A[vi - 1] + 1.0 - - -def test_nontrivial_range_axis(): - func = nontrivial_range_axis - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) - +def loop_extent_dependent(): + @T.prim_func + def loop_extent_dependent(a: T.handle) -> None: + A = T.match_buffer(a, [], dtype="int32") + for i in T.serial(0, 128): + for j in T.serial(0, i): + A[()] = A[()] + j -@T.prim_func -def func_with_target_spec_by_config() -> None: - T.func_attr( - { - "kTarget": T.target( - { - "max_num_threads": 1024, - "arch": "sm_70", - "thread_warp_size": 32, - "kind": "cuda", - "tag": "", - "keys": ["cuda", "gpu"], - } - ) - } - ) - T.evaluate(0) + return loop_extent_dependent -@T.prim_func -def func_with_target_spec_by_str() -> None: - T.func_attr({"kTarget": T.target("nvidia/nvidia-a100")}) - T.evaluate(0) +def nontrivial_range_axis(): + @T.prim_func + def nontrivial_range_axis(a: T.handle) -> None: + A = T.match_buffer(a, (10), "float32") + for i in range(10): + with T.block("block"): + vi = T.axis.spatial((1, 11), i + 1) + A[vi - 1] = A[vi - 1] + 1.0 + return nontrivial_range_axis -def test_func_with_target_spec_by_config(): - func = func_with_target_spec_by_config - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) +def func_with_target_spec_by_config(): + @T.prim_func + def func_with_target_spec_by_config() -> None: + T.func_attr( + { + "kTarget": T.target( + { + "max_num_threads": 1024, + "arch": "sm_70", + "thread_warp_size": 32, + "kind": "cuda", + "tag": "", + "keys": ["cuda", "gpu"], + } + ) + } + ) + T.evaluate(0) -def test_func_with_target_spec_by_str(): - func = func_with_target_spec_by_str - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) + return func_with_target_spec_by_config -@T.prim_func -def func_root_attr(): - with T.block("root"): - T.block_attr({"a": "0"}) +def func_with_target_spec_by_str(): + @T.prim_func + def func_with_target_spec_by_str() -> None: + T.func_attr({"kTarget": T.target("nvidia/nvidia-a100")}) T.evaluate(0) - -def test_root_attr(): - func = func_root_attr - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) + return func_with_target_spec_by_str -@T.prim_func -def func_T_ptr_let_statement( - args: T.handle, arg_type_ids_handle: T.Ptr[T.int32], num_args: T.int32 -) -> None: - # The T.Ptr declaration in the parameter list should parse - # correctly, and should be usable as the data pointer in a buffer. - arg_type_ids = T.buffer_decl([2], dtype="int32", data=arg_type_ids_handle) +def func_root_attr(): + @T.prim_func + def func_root_attr(): + with T.block("root"): + T.block_attr({"a": "0"}) + T.evaluate(0) - arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") - arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") + return func_root_attr - # Functions that return a "handle" can be assigned to a T.Ptr - # variable. A variable annotated with T.Ptr still has dtype of - # T.handle, but has type annotation as a pointer type. - A_data: T.Ptr[T.float32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle") - # The buffer declaration has a data pointer defined earlier in - # this function. It should only be defined after the data pointer - # has been defined, and should not be hoisted into the header of - # the function as other buffer_decl statements can be. - A = T.buffer_decl([1024], dtype="float32", data=A_data) - B_data: T.Ptr[T.float32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle") - B = T.buffer_decl([1024], dtype="float32", data=B_data) +def func_T_ptr_let_statement(): + @T.prim_func + def func_T_ptr_let_statement( + args: T.handle, arg_type_ids_handle: T.Ptr[T.int32], num_args: T.int32 + ) -> None: + # The T.Ptr declaration in the parameter list should parse + # correctly, and should be usable as the data pointer in a buffer. + arg_type_ids = T.buffer_decl([2], dtype="int32", data=arg_type_ids_handle) - B[0] = A[0] + arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") + arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") + # Functions that return a "handle" can be assigned to a T.Ptr + # variable. A variable annotated with T.Ptr still has dtype of + # T.handle, but has type annotation as a pointer type. + A_data: T.Ptr[T.float32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle") -def test_T_ptr_let_statement(): - func = func_T_ptr_let_statement - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) + # The buffer declaration has a data pointer defined earlier in + # this function. It should only be defined after the data pointer + # has been defined, and should not be hoisted into the header of + # the function as other buffer_decl statements can be. + A = T.buffer_decl([1024], dtype="float32", data=A_data) + B_data: T.Ptr[T.float32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + B = T.buffer_decl([1024], dtype="float32", data=B_data) + B[0] = A[0] -@T.prim_func -def func_T_ptr_allocate() -> None: - A = T.allocate([1024], "float32", "global") - A[0] = 0.0 + return func_T_ptr_let_statement -def test_T_ptr_allocate(): - func = func_T_ptr_allocate - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func, True) +def func_T_ptr_allocate(): + @T.prim_func + def func_T_ptr_allocate() -> None: + A = T.allocate([1024], "float32", "global") + A[0] = 0.0 + + return func_T_ptr_allocate + + +ir_generator = tvm.testing.parameter( + opt_gemm_normalize, + opt_gemm_lower, + opt_gemm_mod_host, + opt_conv_tensorcore_normalize, + opt_conv_tensorcore_lower, + opt_conv_tensorcore_mod_host, + vthread_func, + matmul, + rank0, + rank0_block, + select, + minmax, + abs, + constant_folding, + simplify_bracket, + while_loop, + primfunc_with_allocate_annotations, + comm_reducer_single_reduce_group, + comm_reducer_multiple_reduce_groups, + multiple_commreducer, + loop_extent_dependent, + nontrivial_range_axis, + func_with_target_spec_by_config, + func_with_target_spec_by_str, + func_root_attr, + func_T_ptr_let_statement, + func_T_ptr_allocate, +) + + +def test_roundtrip(ir_generator): + original = ir_generator() + after_roundtrip = tvm.script.from_source(original.script(show_meta=True)) + tvm.ir.assert_structural_equal(original, after_roundtrip, True) if __name__ == "__main__": From d7b8f06001f7ce1d0520103cc58a4a90948d0c29 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 1 Feb 2022 14:16:53 -0600 Subject: [PATCH 043/177] Updated TIR reference in USMP pool allocation unit tests. Using let var handles as the data pointer in buffers, rather than just as `T.load`/`T.store` arguments, requires annotation as `T.Ptr[T.primtype]`, rather than as `T.handle`. --- src/tir/usmp/analysis/extract_buffer_info.cc | 12 +-- .../convert_pool_allocations_to_offsets.cc | 99 +++++++++++++------ ...orm_convert_pool_allocations_to_offsets.py | 68 ++++++------- 3 files changed, 101 insertions(+), 78 deletions(-) diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index fb4fb52c507e..2b3f93af79f2 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -73,8 +73,8 @@ class BufferInfoExtractor : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) override; void VisitExpr_(const CallNode* op) override; void VisitExpr_(const VarNode* op) override; - void VisitExpr_(const LoadNode* op) override; - void VisitStmt_(const StoreNode* op) override; + void VisitExpr_(const BufferLoadNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void UpdateAliases(const Array& args, const PrimFunc& func); @@ -306,13 +306,13 @@ void BufferInfoExtractor::VisitStmt_(const ForNode* op) { scope_stack_.pop(); } -void BufferInfoExtractor::VisitExpr_(const LoadNode* op) { - this->VisitExpr(op->buffer_var); +void BufferInfoExtractor::VisitExpr_(const BufferLoadNode* op) { + this->VisitExpr(op->buffer->data); StmtExprVisitor::VisitExpr_(op); } -void BufferInfoExtractor::VisitStmt_(const StoreNode* op) { - this->VisitExpr(op->buffer_var); +void BufferInfoExtractor::VisitStmt_(const BufferStoreNode* op) { + this->VisitExpr(op->buffer->data); StmtExprVisitor::VisitStmt_(op); } diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index 5d267d1a5363..45518d67ea0a 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -89,8 +89,8 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { private: PrimExpr VisitExpr_(const CallNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; - PrimExpr VisitExpr_(const LoadNode* op) override; - Stmt VisitStmt_(const StoreNode* op) override; + PrimExpr VisitExpr_(const BufferLoadNode* op) override; + Stmt VisitStmt_(const BufferStoreNode* op) override; /*! \brief This is a structure where the modified function * signature is kept while body of the function is mutated @@ -130,6 +130,10 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { /*! \brief Obtain a resource handle if its there */ Optional GetResourceHandle(const PrimFunc& func); + /*! \brief Get the Buffer object representing the mapped access into + * the pool. + */ + Buffer GetRemappedBuffer(Buffer buf); /*! \brief The tir::Var map to PoolInfo objects */ Map primfunc_args_to_pool_info_map_; @@ -146,7 +150,15 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { /*! \brief After mutation, each allocate buffer is replaced with tir::Var that is let bounded * to position from a pool as designated by a PoolAllocation */ - Map allocate_buf_to_let_var_; + Map allocate_var_to_let_var_; + /*! \brief A map from the original buffer object + * + * Each key-value pair in this map satisfies + * ``allocate_buf_to_let_var[key->data] = value->data``. However, + * since more than one `tir::Buffer` may use the same Var, they must + * be tracked separately. + */ + Map original_buf_to_let_buf_; /*! \brief A counter to give references to pools a reproducible unique set of names */ int pool_var_count_ = 0; /*! \brief This toggles to remove non tvmscript printable items for IRModule for unit tests */ @@ -183,12 +195,7 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda String var_name = pool_ref_name + "_var"; DataType elem_dtype = DataType::UInt(8); Var buffer_var(var_name, PointerType(PrimType(elem_dtype), "global")); - Var pool_var; - if (!emit_tvmscript_printable_) { - pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global")); - } else { - pool_var = Var(var_name, DataType::Handle(8)); - } + Var pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global")); si.params.push_back(pool_var); si.pools_to_params.Set(pool_info, pool_var); si.allocated_pool_params.push_back(AllocatedPoolInfo( @@ -258,8 +265,8 @@ Array PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs( Array ret; for (const PrimExpr& arg : args) { if (arg->IsInstance() && - allocate_buf_to_let_var_.find(Downcast(arg)) != allocate_buf_to_let_var_.end()) { - ret.push_back(allocate_buf_to_let_var_[Downcast(arg)]); + allocate_var_to_let_var_.find(Downcast(arg)) != allocate_var_to_let_var_.end()) { + ret.push_back(allocate_var_to_let_var_[Downcast(arg)]); } else { ret.push_back(arg); } @@ -300,37 +307,65 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) { PoolAllocation pool_allocation = pool_allocations_[GetRef(op)]; Var param = scope_info.pools_to_params[pool_allocation->pool_info]; Buffer buffer_var = scope_info.buffer_map[param]; - Load load_node = - Load(DataType::UInt(8), buffer_var->data, pool_allocation->byte_offset, op->condition); - Call address_of_load = Call(DataType::Handle(8), builtin::address_of(), {load_node}); - Var tir_var; - if (!emit_tvmscript_printable_) { - tir_var = Var(op->buffer_var->name_hint + "_let", op->buffer_var->type_annotation); - } else { - tir_var = Var(op->buffer_var->name_hint + "_let", DataType::Handle(8)); + BufferLoad load_node = BufferLoad(buffer_var, {pool_allocation->byte_offset}); + Call address_of_load = Call(DataType::Handle(), builtin::address_of(), {load_node}); + + Type let_var_type = op->buffer_var->type_annotation; + if (emit_tvmscript_printable_) { + // Strip the storage_scope from the variable type, as TVMScript + // doesn't parsethe scoped pointers (e.g. ``T.Ptr[global T.int32]``) + // correctly. + let_var_type = PointerType(Downcast(let_var_type)->element_type); } - allocate_buf_to_let_var_.Set(op->buffer_var, tir_var); + Var let_var(op->buffer_var->name_hint + "_let", let_var_type); + allocate_var_to_let_var_.Set(op->buffer_var, let_var); Stmt new_body = VisitStmt(op->body); - allocate_buf_to_let_var_.erase(op->buffer_var); - return LetStmt(tir_var, address_of_load, new_body); + allocate_var_to_let_var_.erase(op->buffer_var); + return LetStmt(let_var, address_of_load, new_body); } return StmtExprMutator::VisitStmt_(op); } -Stmt PoolAllocationToOffsetConverter::VisitStmt_(const StoreNode* op) { - if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) { - return Store(allocate_buf_to_let_var_[op->buffer_var], VisitExpr(op->value), op->index, - VisitExpr(op->predicate)); +Stmt PoolAllocationToOffsetConverter::VisitStmt_(const BufferStoreNode* op) { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + + Buffer remapped = GetRemappedBuffer(store->buffer); + if (!op->buffer.same_as(remapped)) { + store.CopyOnWrite()->buffer = remapped; } - return StmtExprMutator::VisitStmt_(op); + return std::move(store); } -PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const LoadNode* op) { - if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) { - return Load(op->dtype, allocate_buf_to_let_var_[op->buffer_var], op->index, - VisitExpr(op->predicate)); +PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op) { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + + Buffer remapped = GetRemappedBuffer(load->buffer); + if (!op->buffer.same_as(remapped)) { + load.CopyOnWrite()->buffer = remapped; } - return StmtExprMutator::VisitExpr_(op); + return std::move(load); +} + +Buffer PoolAllocationToOffsetConverter::GetRemappedBuffer(Buffer original) { + { + auto it = original_buf_to_let_buf_.find(original); + if (it != original_buf_to_let_buf_.end()) { + return (*it).second; + } + } + + Buffer remapped = original; + + auto it = allocate_var_to_let_var_.find(original->data); + if (it != allocate_var_to_let_var_.end()) { + remapped = Buffer((*it).second, original->dtype, original->shape, original->strides, + original->elem_offset, original->name, original->data_alignment, + original->offset_factor, original->buffer_type, original->axis_separators, + original->span); + } + + original_buf_to_let_buf_.Set(original, remapped); + return remapped; } IRModule PoolAllocationToOffsetConverter::operator()() { diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 1d42dade372e..404832f814a4 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -140,20 +140,20 @@ def run_model(input: T.handle, output: T.handle) -> None: @tvm.script.ir_module class LinearStructurePlanned: @T.prim_func - def run_model(input: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T.handle, output: T.handle) -> None: + def run_model(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_var: T.Ptr[T.uint8], output: T.handle) -> None: fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_9_let: T.handle = T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle") - sid_8_let: T.handle = T.address_of(slow_memory_1_buffer_var[0], dtype="handle") + sid_9_let: T.Ptr[T.int8] = T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle") + sid_8_let: T.Ptr[T.int8] = T.address_of(slow_memory_1_buffer_var[0], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) @T.prim_func - def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.handle, slow_memory_7_var: T.handle) -> None: + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) @@ -170,7 +170,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T_cast_7[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3] = T.cast(tensor_2_let[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3], "int16") @T.prim_func - def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle, slow_memory_3_var: T.handle) -> None: + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr[T.uint8], slow_memory_3_var: T.Ptr[T.uint8]) -> None: placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16") @@ -181,7 +181,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle, slow_memory_5_var: T.handle) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr[T.uint8], slow_memory_5_var: T.Ptr[T.uint8]) -> None: placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16") placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32") @@ -189,12 +189,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_7_let = T.buffer_decl([307856], "int16") + PaddedInput_7_let = T.buffer_decl([157323], "int16") with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3): PaddedInput_7_let[i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7] = T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, placeholder_65[i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): - Conv2dOutput_7_let = T.buffer_decl([50176], "int32") + Conv2dOutput_7_let = T.buffer_decl([64], "int32") with T.let(Conv2dOutput_7_let.data, T.address_of(fast_memory_4_buffer_var[0], dtype="handle")): for ff_3 in T.serial(0, 64): Conv2dOutput_7_let[ff_3] = 0 @@ -237,17 +237,10 @@ def test_mobilenet_subgraph(): )(tir_mod) tir_mod_with_offsets_ref = LinearStructurePlanned - tir_mod_with_offsets_ref = tvm.script.from_source( - tir_mod_with_offsets_ref.script(show_meta=False) - ) - # The TIR produced fails on roundtrip TVMScript testing. - # Therefore, indicates the TVMScript produced here and/or the parser - # is lacking functionality. Thus for these tests, uses a string - # version of the TVMScript for each function as a check instead. - for gv, func in tir_mod_with_offsets_ref.functions.items(): - assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str( - tir_mod_with_offsets[gv.name_hint].script() - ) + + for gv, ref_func in tir_mod_with_offsets_ref.functions.items(): + actual_func = tir_mod_with_offsets[gv.name_hint] + tvm.ir.assert_structural_equal(actual_func, ref_func) # fmt: off @@ -374,7 +367,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @tvm.script.ir_module class ResnetStructurePlanned: @T.prim_func - def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None: + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr[T.uint8]) -> None: placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") @@ -384,7 +377,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") @@ -397,7 +390,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): PaddedInput_3_let[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): - Conv2dOutput_3_let = T.buffer_decl([180064], 'int32') + Conv2dOutput_3_let = T.buffer_decl([64], 'int32') with T.let(Conv2dOutput_3_let.data, T.address_of(global_workspace_5_buffer_var[7200000], dtype="handle")): for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): @@ -408,7 +401,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3_let[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr[T.uint8]) -> None: placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") @@ -431,7 +424,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2_let[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136 @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr[T.uint8]) -> None: placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") @@ -453,19 +446,19 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_let[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr[T.uint8]) -> None: placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body - PaddedInput_1_let = T.buffer_decl([3600000], "int16") + PaddedInput_1_let = T.buffer_decl([379456], "int16") with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): PaddedInput_1_let[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): - Conv2dOutput_1_let = T.buffer_decl([180064], "int32") + Conv2dOutput_1_let = T.buffer_decl([64], "int32") with T.let(Conv2dOutput_1_let.data, T.address_of(global_workspace_3_buffer_var[7200000], dtype="handle")): for ff_1 in T.serial(0, 64): Conv2dOutput_1_let[ff_1] = 0 @@ -475,15 +468,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1_let[ax3_inner_2] + placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func - def run_model(input: T.handle, global_workspace_0_var: T.handle, output: T.handle) -> None: + def run_model(input: T.handle, global_workspace_0_var: T.Ptr[T.uint8], output: T.handle) -> None: global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_2_let: T.handle = T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle") - sid_6_let: T.handle = T.address_of(global_workspace_0_buffer_var[0], dtype="handle") - sid_7_let: T.handle = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") - sid_8_let: T.handle = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") + sid_2_let: T.Ptr[T.int8] = T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle") + sid_6_let: T.Ptr[T.int8] = T.address_of(global_workspace_0_buffer_var[0], dtype="handle") + sid_7_let: T.Ptr[T.int8] = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") + sid_8_let: T.Ptr[T.int8] = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32")) @@ -520,14 +513,9 @@ def test_resnet_subgraph(): tir_mod_with_offsets_ref = ResnetStructurePlanned - # The TIR produced fails on roundtrip TVMScript testing. - # Therefore, indicates the TVMScript produced here and/or the parser - # is lacking functionality. Thus for these tests, uses a string - # version of the TVMScript for each function as a check instead. - for gv, func in tir_mod_with_offsets_ref.functions.items(): - assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str( - tir_mod_with_offsets[gv.name_hint].script() - ) + for gv, ref_func in tir_mod_with_offsets_ref.functions.items(): + actual_func = tir_mod_with_offsets[gv.name_hint] + tvm.ir.assert_structural_equal(actual_func, ref_func) if __name__ == "__main__": From 49eeecab3af06099efbab7836f4c083feebbe776 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 09:13:23 -0600 Subject: [PATCH 044/177] fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate --- ...st_tir_usmp_analysis_extract_bufferinfo.py | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index 0d39e2eecf50..78f97b43c00d 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -156,9 +156,9 @@ def run_model(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) __tvm_meta__ = None # fmt: on @@ -630,21 +630,21 @@ def run_model(input: T.handle, output: T.handle) -> None: sid_25 = T.allocate([25088], "int8", "global") sid_26 = T.allocate([25088], "int8", "global") sid_31 = T.allocate([25088], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, sid_7, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_7, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_6, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", sid_6, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_5, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d", sid_5, sid_4, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_cast", sid_4, sid_3, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", sid_3, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_2, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_3, T.lookup_param("p9", dtype="handle"), T.lookup_param("p10", dtype="handle"), sid_20, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", sid_20, T.lookup_param("p11", dtype="handle"), T.lookup_param("p12", dtype="handle"), sid_19, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", sid_3, T.lookup_param("p13", dtype="handle"), T.lookup_param("p14", dtype="handle"), sid_26, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", sid_26, T.lookup_param("p15", dtype="handle"), T.lookup_param("p16", dtype="handle"), sid_25, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast_1", sid_4, sid_32, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", sid_32, T.lookup_param("p17", dtype="handle"), T.lookup_param("p18", dtype="handle"), sid_31, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_concatenate", sid_2, sid_19, sid_25, sid_31, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, sid_7.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_7.data, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_6.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", sid_6.data, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_5.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d", sid_5.data, sid_4.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast", sid_4.data, sid_3.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", sid_3.data, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_2.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_3.data, T.lookup_param("p9", dtype="handle"), T.lookup_param("p10", dtype="handle"), sid_20.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", sid_20.data, T.lookup_param("p11", dtype="handle"), T.lookup_param("p12", dtype="handle"), sid_19.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", sid_3.data, T.lookup_param("p13", dtype="handle"), T.lookup_param("p14", dtype="handle"), sid_26.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", sid_26.data, T.lookup_param("p15", dtype="handle"), T.lookup_param("p16", dtype="handle"), sid_25.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast_1", sid_4.data, sid_32.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", sid_32.data, T.lookup_param("p17", dtype="handle"), T.lookup_param("p18", dtype="handle"), sid_31.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_concatenate", sid_2.data, sid_19.data, sid_25.data, sid_31.data, output, dtype="int32")) __tvm_meta__ = None # fmt: on @@ -1346,20 +1346,20 @@ def run_model(data: T.handle, output: T.handle) -> None: sid_18 = T.allocate([3456], "int8", "global.workspace") sid_19 = T.allocate([3456], "int8", "global.workspace") sid_20 = T.allocate([3456], "int8", "global.workspace") - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", data_buffer.data, sid_8, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_8, T.cast(T.lookup_param("p0", dtype="handle"), "handle"), sid_7, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_7, sid_6, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", data_buffer.data, sid_12, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_12, T.cast(T.lookup_param("p1", dtype="handle"), "handle"), sid_11, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_11, sid_10, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add_add_multiply_add", sid_6, sid_10, T.cast(T.lookup_param("p2", dtype="handle"), "handle"), T.cast(T.lookup_param("p3", dtype="handle"), "handle"), T.cast(T.lookup_param("p4", dtype="handle"), "handle"), sid_5, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", sid_5, sid_4, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_4, T.cast(T.lookup_param("p5", dtype="handle"), "handle"), sid_3, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_3, sid_2, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", sid_5, sid_20, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_20, T.cast(T.lookup_param("p6", dtype="handle"), "handle"), sid_19, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_19, sid_18, dtype="int32")) - T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add", sid_2, sid_18, output_buffer.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", data_buffer.data, sid_8.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_8.data, T.cast(T.lookup_param("p0", dtype="handle"), "handle"), sid_7.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_7.data, sid_6.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", data_buffer.data, sid_12.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_12.data, T.cast(T.lookup_param("p1", dtype="handle"), "handle"), sid_11.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_11.data, sid_10.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add_add_multiply_add", sid_6.data, sid_10.data, T.cast(T.lookup_param("p2", dtype="handle"), "handle"), T.cast(T.lookup_param("p3", dtype="handle"), "handle"), T.cast(T.lookup_param("p4", dtype="handle"), "handle"), sid_5.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", sid_5.data, sid_4.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_4.data, T.cast(T.lookup_param("p5", dtype="handle"), "handle"), sid_3.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_3.data, sid_2.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", sid_5.data, sid_20.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_20.data, T.cast(T.lookup_param("p6", dtype="handle"), "handle"), sid_19.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_19.data, sid_18.data, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add", sid_2.data, sid_18.data, output_buffer.data, dtype="int32")) # fmt: on From ff663395ce1c8d6a17d5bf361d8af0c50185bb9b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 09:15:03 -0600 Subject: [PATCH 045/177] fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate --- tests/python/unittest/test_tir_usmp_algo.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py index 96def166ea43..1cde5a1d6b5c 100644 --- a/tests/python/unittest/test_tir_usmp_algo.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -125,7 +125,7 @@ def test_no_pool_error(): @pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts", "hill_climb"]) def test_name_based_ordering(algorithm): - """ This checks when the size and conlicts are same a stable result is generated""" + """This checks when the size and conlicts are same a stable result is generated""" def _test(): target = Target("c") @@ -355,9 +355,9 @@ def run_model(input: T.handle, output: T.handle) -> None: T.attr("default", "device_type", 1) sid_9 = T.allocate([301056], "int8", "global") sid_8 = T.allocate([802816], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8.data, output, dtype="int32")) __tvm_meta__ = None # fmt: on @@ -502,11 +502,11 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: sid_6 = T.allocate([5760000], "int8", "global") sid_7 = T.allocate([720000], "int8", "global") sid_8 = T.allocate([720000], "int8", "global") - T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32")) - T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2.data, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8.data, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7.data, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2.data, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6.data, output, dtype="int32")) @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: From 5bebd09fd9829ac02a29dcd4fe29c4a8bc0aca4d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 09:19:34 -0600 Subject: [PATCH 046/177] fixup! Replacing all T.store TIR calls. --- .../unittest/test_tir_transform_convert_for_loops_serial.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py index 862fb31ee40a..d3b8fe40dbf1 100644 --- a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -40,8 +40,8 @@ def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T. Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global") Conv2dOutput_3[0] = 0 for rc_3 in T.serial(0, 192): - Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34[((rc_3*16) + ax3_2)], "int32"))), True - T_cast_9[((ax0_ax1_fused_ax2_fused_3*16) + ax3_2)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35[ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True + Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34[((rc_3*16) + ax3_2)], "int32"))) + T_cast_9[((ax0_ax1_fused_ax2_fused_3*16) + ax3_2)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35[ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") # fmt: on From 2856fd94f549b4ec1403dcc68f57f16d90d4b0de Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 10:09:16 -0600 Subject: [PATCH 047/177] fixup! Replacing all T.store TIR calls. --- .../test_tir_transform_compact_buffer_region.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 9a3799ba5f46..1a732a871d53 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -80,7 +80,8 @@ def unschedulable_func(a: T.handle, c: T.handle) -> None: T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - B[i * 16 + j] = A[i, j]y + 1.0 + T.evaluate(T.call_extern("dummy_extern_function", B.data, dtype="int32")) + B[i, j] = A[i, j] + 1.0 for j in range(0, 16): C[i, j] = B[i, j] * 2.0 @@ -251,7 +252,7 @@ def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(4, 8): D[k, j] = 1.0 for k in range(2, 4): - B[j] = A[i, j] + Dy[k, j] + B[i, j] = A[i, j] + D[k, j] for j in range(3, 5): with T.block() as []: T.reads(B[i, j]) @@ -281,7 +282,7 @@ def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(4, 8): D[k - 2, 0] = 1.0 for k in range(2, 4): - B[j] = A[i, j] + D[k -y 2, 0] + B[0, j] = A[i, j] + D[k - 2, 0] for j in range(3, 5): with T.block() as []: T.reads(B[0, j]) @@ -476,12 +477,14 @@ def opaque_access_annotated_func(a: T.handle) -> None: # no annotation, opaque access will cover full region T.reads([]) T.writes([]) + T.evaluate(T.call_extern("opaque_extern_function", A.data, B.data, dtype="int32")) B[i] = A[i] with T.block(): # treat opaque access only access annotated regions, even if # they are not compatible with actual buffer accesses. T.reads([B[i]]) T.writes([C[i : i + 9]]) + T.evaluate(T.call_extern("opaque_extern_function", B.data, C.data, dtype="int32")) C[i] = B[i] @@ -496,12 +499,14 @@ def compacted_opaque_access_annotated_func(a: T.handle) -> None: # no annotation, opaque access will cover full region T.reads([]) T.writes([]) + T.evaluate(T.call_extern("opaque_extern_function", A.data, B.data, dtype="int32")) B[i] = A[i] with T.block(): # treat opaque access only access annotated regions, even if # they are not compatible with actual buffer accesses. T.reads([B[i]]) T.writes([C[i : i + 9]]) + T.evaluate(T.call_extern("opaque_extern_function", B.data, C.data, dtype="int32")) C[i] = B[i] From 4d19cc70a77401986325f0603c7dd49a4543ed29 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 11:07:46 -0600 Subject: [PATCH 048/177] fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate --- tests/python/unittest/test_tir_ptx_mma.py | 102 +++++++++++----------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py index 4b8e3fcaffef..2c6b80446f9b 100644 --- a/tests/python/unittest/test_tir_ptx_mma.py +++ b/tests/python/unittest/test_tir_ptx_mma.py @@ -52,11 +52,11 @@ def gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle): "fp64", "fp64", "fp64", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float64", @@ -132,11 +132,11 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle): "fp16", "fp16", "fp16", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float16", @@ -213,11 +213,11 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): "fp16", "fp16", "fp32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float32", @@ -294,11 +294,11 @@ def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "int8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -368,11 +368,11 @@ def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "uint8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -442,11 +442,11 @@ def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): "int4", "int4", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -508,11 +508,11 @@ def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): "int4", "uint4", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -578,11 +578,11 @@ def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle) "fp16", "fp16", "fp32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float32", @@ -658,11 +658,11 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle "fp16", "fp16", "fp16", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float16", @@ -740,11 +740,11 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle "fp16", "fp16", "fp32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="float32", @@ -822,11 +822,11 @@ def gemm_mma_m16n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "int8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -904,11 +904,11 @@ def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "uint8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -986,11 +986,11 @@ def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "int8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -1068,11 +1068,11 @@ def gemm_mma_m16n8k32_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): "int8", "uint8", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -1150,11 +1150,11 @@ def gemm_mma_m16n8k64_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): "int4", "int4", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -1224,11 +1224,11 @@ def gemm_mma_m16n8k64_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): "int4", "uint4", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", @@ -1298,11 +1298,11 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle): "int1", "int1", "int32", - MultiA, + MultiA.data, 0, - MultiB, + MultiB.data, 0, - Accum, + Accum.data, 0, False, dtype="int32", From 00292c36640f46f5c1006f37cdb07badb9263914 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 11:08:05 -0600 Subject: [PATCH 049/177] fixup! In test directory, replacing all instances of T.load. --- tests/python/unittest/test_tir_ptx_mma.py | 48 +++++++++-------------- 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py index 2c6b80446f9b..c0322d38c29c 100644 --- a/tests/python/unittest/test_tir_ptx_mma.py +++ b/tests/python/unittest/test_tir_ptx_mma.py @@ -63,9 +63,7 @@ def gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle): ) ) for mma_accum_c_id in range(2): - C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( - "float64", Accum, mma_accum_c_id - ) + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -146,7 +144,7 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle): C[ ((tx % 32) % 4) + (4 * ((((tx % 32) // 16 + (tx % 32) % 16 // 4 * 2)) % 4)), mma_accum_c_id % 4 + (4 * ((tx % 32) % 16 // 8)) + mma_accum_c_id // 4 * 8, - ] = T.load("float16", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -233,7 +231,7 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): + (tx % 32) % 16 // 8 * 4 + mma_accum_c_id % 2 + mma_accum_c_id // 4 * 8, - ] = T.load("float32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -305,9 +303,7 @@ def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): ) ) for mma_accum_c_id in range(2): - C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( - "int32", Accum, mma_accum_c_id - ) + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -379,9 +375,7 @@ def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): ) ) for mma_accum_c_id in range(2): - C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( - "int32", Accum, mma_accum_c_id - ) + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -453,9 +447,7 @@ def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): ) ) for mma_accum_c_id in range(2): - C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( - "int32", Accum, mma_accum_c_id - ) + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -519,9 +511,7 @@ def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): ) ) for mma_accum_c_id in range(2): - C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( - "int32", Accum, mma_accum_c_id - ) + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -589,9 +579,9 @@ def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle) ) ) for mma_accum_c_id in range(4): - C[ - (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2 - ] = T.load("float32", Accum, mma_accum_c_id) + C[(tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2] = Accum[ + mma_accum_c_id + ] @tvm.testing.requires_cuda @@ -672,7 +662,7 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("float16", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -754,7 +744,7 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("float32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -836,7 +826,7 @@ def gemm_mma_m16n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -918,7 +908,7 @@ def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -1000,7 +990,7 @@ def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -1082,7 +1072,7 @@ def gemm_mma_m16n8k32_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -1164,7 +1154,7 @@ def gemm_mma_m16n8k64_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -1238,7 +1228,7 @@ def gemm_mma_m16n8k64_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda @@ -1312,7 +1302,7 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle): C[ (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2, - ] = T.load("int32", Accum, mma_accum_c_id) + ] = Accum[mma_accum_c_id] @tvm.testing.requires_cuda From 029496e7b4d9042b35414cc58ad91616f5110848 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 13:37:22 -0600 Subject: [PATCH 050/177] tir.ComputeInline, correct variable count. Previously, this metaschedule primitive relied on `tir::UndefinedVars` ignoring the data pointer of BufferLoad/BufferStore nodes. When `tir::UndefinedVars` was updated to visit the data pointer, similar to the previous behavior when visiting Load/Store nodes, this caused the count of undefined variables to be unexpectedly high. --- src/tir/schedule/primitive/compute_inline.cc | 30 ++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 6dccfc311cd4..d7556ed73995 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -284,6 +284,31 @@ class BaseInliner : public StmtExprMutator { } } + /*! + * \brief Count the number of undefined variables that are not used + * as buffer objects. + * + * This is used to determine whether inlining or reverse inlining is + * possible. The only undefined variables present should be the + * load/store indices, or buffer access based on those indices. + * + * \param stmt The statement in which to count undefined variables + */ + static int GetNumUndefinedNonpointerVars(const Stmt& stmt) { + auto undefined_vars = UndefinedVars(stmt, {}); + // Buffer pointers and the inlined indices are allowed, but no + // other variables may appear in the inlined block. + int num_nonpointer_vars = 0; + for (const auto& var : undefined_vars) { + bool is_pointer = var->dtype.is_handle() && var->type_annotation.defined() && + var->type_annotation.as(); + if (!is_pointer) { + num_nonpointer_vars++; + } + } + return num_nonpointer_vars; + } + private: /*! * \brief Add the buffers in the block signature to the `buffer_var_map_`, @@ -417,7 +442,8 @@ class ComputeInliner : public BaseInliner { if (inlined_store_ == nullptr) { return false; } - int n_vars = UndefinedVars(GetRef(inlined_store_), {}).size(); + + int n_vars = GetNumUndefinedNonpointerVars(GetRef(inlined_store_)); if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) { return false; } @@ -484,7 +510,7 @@ class ReverseComputeInliner : public BaseInliner { // Failure: no BufferLoad from the `inlined_buffer_` return false; } - int n_vars = UndefinedVars(GetRef(inlined_store_), {}).size(); + int n_vars = GetNumUndefinedNonpointerVars(GetRef(inlined_store_)); for (const BufferLoadNode* load : loads) { if (!UpdateAndCheckIndexVars(load->indices, n_vars)) { // Failure: incorrect of inconsistent index vars From 3676f80b8939d30865366723e24f0f74e287a921 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 14:03:30 -0600 Subject: [PATCH 051/177] fixup! Replacing all T.store TIR calls. --- tests/python/unittest/test_tir_nodes.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index fe719ee99693..88ad568691a2 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -16,7 +16,7 @@ # under the License. import pytest import tvm -from tvm import te +from tvm import te, ir import numpy as np @@ -84,11 +84,18 @@ def test_ir(): def test_ir2(): + buf_size = te.var("size") x = te.var("n") - a = te.var("array", "handle") - st = tvm.tir.Store(a, x + 1, 1) - assert isinstance(st, tvm.tir.Store) - assert st.buffer_var == a + + storage_type = ir.PrimType("int32") + handle_type = ir.PointerType(storage_type) + array = te.var("array", handle_type) + buf = tvm.tir.decl_buffer([buf_size], "int32", data=array) + + st = tvm.tir.BufferStore(buf, x + 1, [1]) + assert isinstance(st, tvm.tir.BufferStore) + assert st.buffer == buf + assert st.buffer.data == array def test_let(): From cafadcca8bf259fd893e6d27adc01541f2ae9813 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 14:21:04 -0600 Subject: [PATCH 052/177] fixup! Updated Buffer::vstore/vload to return BufferLoad/BufferStore objects. --- src/tir/transforms/lower_match_buffer.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index f956c09f0457..5bde5cb90e2b 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -177,7 +177,7 @@ class MatchBufferLower : public StmtExprMutator { Bind(buffer->data, source_buffer->data, buffer->name + ".data"); // Step.2.2. Update element offset - // Note we create Load via vload and try to reuse index calculate. + // We use the ElemOffset method to avoid duplicating the index calculation. { Array indices; indices.reserve(source->region.size()); @@ -185,11 +185,11 @@ class MatchBufferLower : public StmtExprMutator { indices.push_back(range->min); } - auto load = Downcast(source_buffer.vload(indices, source_buffer->dtype)); - if (load->indices.size() == 1) { - Bind(buffer->elem_offset, load->indices[0], buffer->name + ".elem_offset"); + Array buffer_start_indices = source_buffer->ElemOffset(indices); + if (buffer_start_indices.size() == 1) { + Bind(buffer->elem_offset, buffer_start_indices[0], buffer->name + ".elem_offset"); CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) - << "The source elem_offset " << load->indices[0] + << "The source elem_offset " << buffer_start_indices[0] << " does not satisfy the offset_factor " << buffer->offset_factor << "."; } else { // Non-zero elem_offset is ill-defined for non-flat memory. From 04995f676448a98eaebcc74ec3c6f17114580ff7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 14:24:11 -0600 Subject: [PATCH 053/177] fixup! In test directory, replacing all instances of T.load. --- tests/python/unittest/test_tir_ir_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 5b123e883849..9438da17ede2 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -56,8 +56,8 @@ def test_if(): body = body.body assert isinstance(body, tvm.tir.IfThenElse) assert isinstance(body.condition, tvm.tir.EQ) - assert isinstance(body.then_case.index, tvm.tir.Var) - assert body.else_case.index.value == 0 + assert isinstance(body.then_case.indices[0], tvm.tir.Var) + assert list(body.else_case.indices) == [0] def test_prefetch(): From 9a22d145088e6559f284274e42c679a85f4c49e0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 14:47:49 -0600 Subject: [PATCH 054/177] fixup! In test directory, replacing all instances of T.load. --- tests/python/unittest/test_tir_constructor.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 00aba46ba431..339320d53120 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -87,13 +87,14 @@ def test_expr_constructor(): assert x.false_value == b assert x.condition == a - buffer_var = te.var("x", dtype="handle") - x = tvm.tir.Load("float32", buffer_var, 1, a) - assert isinstance(x, tvm.tir.Load) + buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"))) + buffer = tvm.tir.decl_buffer([16], "float32", data=buffer_var) + x = tvm.tir.BufferLoad(buffer, [1]) + assert isinstance(x, tvm.tir.BufferLoad) assert x.dtype == "float32" - assert x.buffer_var == buffer_var - assert x.index.value == 1 - assert x.predicate == a + assert x.buffer == buffer + assert x.buffer.data == buffer_var + assert list(x.indices) == [1] x = tvm.tir.Ramp(1, 2, 10) assert isinstance(x, tvm.tir.Ramp) From e1a16b748bc009f18a4e5d72f65cb0df30cd8658 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 14:47:57 -0600 Subject: [PATCH 055/177] fixup! Replacing all T.store TIR calls. --- tests/python/unittest/test_tir_constructor.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 339320d53120..dcd642c3b9ec 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -127,7 +127,6 @@ def test_expr_constructor(): def test_stmt_constructor(): v = te.var("aa") - buffer_var = te.var("buf", dtype="handle") nop = tvm.tir.Evaluate(1) x = tvm.tir.LetStmt(v, 1, tvm.tir.Evaluate(1)) assert isinstance(x, tvm.tir.LetStmt) @@ -149,10 +148,13 @@ def test_stmt_constructor(): assert x.extent.value == 10 assert x.body == nop - x = tvm.tir.Store(buffer_var, 1, 10, tvm.tir.const(1, "uint1")) - assert isinstance(x, tvm.tir.Store) - assert x.buffer_var == buffer_var - assert x.index.value == 10 + buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("uint1"))) + buffer = tvm.tir.decl_buffer([16], "uint1", data=buffer_var) + x = tvm.tir.BufferStore(buffer, 1, [10]) + assert isinstance(x, tvm.tir.BufferStore) + assert x.buffer == buffer + assert x.buffer.data == buffer_var + assert list(x.indices) == [10] assert x.value.value == 1 buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"))) From c69d569c1d66c9931e7f6351e96aa40725a8e9de Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 15:19:18 -0600 Subject: [PATCH 056/177] Expose Buffer index flattening function to Python. --- include/tvm/tir/buffer.h | 8 ++++++++ python/tvm/tir/buffer.py | 17 +++++++++++++++++ src/tir/ir/buffer.cc | 6 ++++++ 3 files changed, 31 insertions(+) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 1140648fd41b..69d6777d87f1 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -214,6 +214,14 @@ class Buffer : public ObjectRef { */ Buffer GetFlattenedBuffer() const; + /*! \brief Determine the offset in the buffer of the given index. + * + * Returns the buffer offset, in number of elements of type dtype, + * without adjusting for number of lanes. (e.g. The number of + * float16x4 elements in a buffer of type float16x4.) + */ + Array OffsetOf(Array index) const; + /*! * \brief Return the storage scope associated with this buffer. */ diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index d60f6185d0e6..53c0916e599f 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -153,6 +153,23 @@ def get_flattened_buffer(self): """ return _ffi_api.BufferGetFlattenedBuffer(self) # type: ignore + def offset_of(self, indices): + """Determine the offset of the provided indices in the flattened buffer. + + Params + ------- + indices : Union[PrimExpr, List[PrimExpr]] + + The indices of the element in the original buffer. + + Returns + ------- + flattened_indices: List[PrimExpr] + + The offset indices of the element in the flattened buffer. + """ + return _ffi_api.BufferOffsetOf(self, indices) # type: ignore + def decl_buffer( shape, diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index f192d6bd11c9..828433f72e2e 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -243,6 +243,10 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { return no_opt_sum; } +Array Buffer::OffsetOf(Array input_indices) const { + return (*this)->ElemOffset(std::move(input_indices)); +} + // The buffer offset in convention of number of elements of // original data ignoring number of lanes. // We also perform optimization to simplify the indexing expression. @@ -577,6 +581,8 @@ TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); TVM_REGISTER_GLOBAL("tir.BufferGetFlattenedBuffer").set_body_method(&Buffer::GetFlattenedBuffer); +TVM_REGISTER_GLOBAL("tir.BufferOffsetOf").set_body_method(&Buffer::OffsetOf); + TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); From b5644f212c3211b82094fe46bebf1a12a7be82c1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 15:20:42 -0600 Subject: [PATCH 057/177] Updated test_tir_buffer.py offset tests. Replacing calls to `Buffer.vload` with `Buffer.offset_of`, when testing the index calculations. --- tests/python/unittest/test_tir_buffer.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 422d730160b5..5eac95ad77aa 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -124,32 +124,32 @@ def assert_simplified_equal(index_simplified, index_direct): idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod # Test Case1 - index_simplified = A_stride.vload( + index_simplified = A_stride.offset_of( (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1) ) - index_direct = A_stride.vload((0, k0)) + index_direct = A_stride.offset_of((0, k0)) assert_simplified_equal(index_simplified, index_direct) # Test Case2 - index_simplified = A.vload( + index_simplified = A.offset_of( (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)) ) - index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s)))) + index_direct = A.offset_of((0, idxm(k0, k1) + idxm(k0, idxd(k1, s)))) assert_simplified_equal(index_simplified, index_direct) # Test Case3 - index_simplified = A.vload( + index_simplified = A.offset_of( ( idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) + idxd(idxm(k0, idxd(k1, s)), n), idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) + idxm(idxm(k0, idxd(k1, s)), n), ) ) - index_direct = A.vload((0, k0)) + index_direct = A.offset_of((0, k0)) assert_simplified_equal(index_simplified, index_direct) # Test Case4 (not able to simplify) - index_simplified = A.vload( + index_simplified = A.offset_of( (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)) ) - index_direct = A.vload( + index_direct = A.offset_of( (0, idxd(idxm(k0, idxd(k1, s)), n) * n + (idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))) ) assert_simplified_equal(index_simplified, index_direct) @@ -160,7 +160,7 @@ def assert_simplified_equal(index_simplified, index_direct): j = te.size_var("j") k = te.size_var("k") - index_simplified = B.vload( + index_simplified = B.offset_of( ( idxd(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), idxm(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), @@ -168,7 +168,7 @@ def assert_simplified_equal(index_simplified, index_direct): idxm((i * 50176 + j * 28672 + k), 1024), ) ) - index_direct = B.vload((0, 0, 0, (i * 50176 + j * 28672 + k))) + index_direct = B.offset_of((0, 0, 0, (i * 50176 + j * 28672 + k))) assert_simplified_equal(index_simplified, index_direct) From 777b3799b1396d44ee8328a3270de13521816a16 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 15:26:23 -0600 Subject: [PATCH 058/177] fixup! Replacing all T.store TIR calls. --- tests/python/unittest/test_tir_transform_vectorize.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py index b1e580957b24..14f33fb9549c 100644 --- a/tests/python/unittest/test_tir_transform_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -35,7 +35,8 @@ def test_vectorize_loop(): assert isinstance(stmt, tvm.tir.For) assert not isinstance(stmt.body, tvm.tir.For) - assert isinstance(stmt.body.index, tvm.tir.Ramp) + assert len(stmt.body.indices) == 1 + assert isinstance(stmt.body.indices[0], tvm.tir.Ramp) assert isinstance(stmt.body.value, tvm.tir.Broadcast) @@ -55,7 +56,8 @@ def test_vectorize_vector(): assert isinstance(stmt, tvm.tir.For) assert not isinstance(stmt.body, tvm.tir.For) - assert isinstance(stmt.body.index, tvm.tir.Ramp) + assert len(stmt.body.indices) == 1 + assert isinstance(stmt.body.indices[0], tvm.tir.Ramp) assert isinstance(stmt.body.value, tvm.tir.Broadcast) @@ -76,7 +78,8 @@ def test_vectorize_with_if(): stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body assert isinstance(stmt, tvm.tir.IfThenElse) - assert isinstance(stmt.then_case.index, tvm.tir.Ramp) + assert len(stmt.then_case.indices) == 1 + assert isinstance(stmt.then_case.indices[0], tvm.tir.Ramp) assert isinstance(stmt.then_case.value, tvm.tir.Add) assert stmt.then_case.value.dtype == "float32x4" assert isinstance(stmt.else_case, tvm.tir.For) From fc13666a991889bd4c59a720de2a6668a7f7aa40 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 15:28:46 -0600 Subject: [PATCH 059/177] fixup! Replacing all T.store TIR calls. --- tests/python/unittest/test_tir_transform_unroll_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index b511118f8b52..7989fba2d29a 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -90,7 +90,7 @@ def test_unroll_fake_loop(): } ): ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body - assert isinstance(ret[0], tvm.tir.Store) + assert isinstance(ret[0], tvm.tir.BufferStore) def test_unroll_single_count_loops(): From 391a28bf4390acc4ae3be1a01b906da293667ddf Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Feb 2022 15:55:34 -0600 Subject: [PATCH 060/177] fixup! Updated Buffer::vstore/vload to return BufferLoad/BufferStore objects. --- src/tir/ir/buffer.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 828433f72e2e..1e470f68ada1 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -252,7 +252,9 @@ Array Buffer::OffsetOf(Array input_indices) const { // We also perform optimization to simplify the indexing expression. Array BufferNode::ElemOffset(Array input_indices) const { ICHECK_EQ(shape.size(), input_indices.size()) - << "Dimensionality of buffer must match dimensionality of index used to access it"; + << "Buffer " << this->name << " is " << shape.size() + << "-dimensional, cannot be indexed with the " << input_indices.size() + << "-dimensional indices provided."; if (strides.size()) { ICHECK_EQ(this->strides.size(), input_indices.size()) From 6c1f86167c96f912cf8e5643b9f5460aa154d89b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 3 Feb 2022 14:52:41 -0600 Subject: [PATCH 061/177] fixup! Replacing Store/Load in lowering/legalization passes. --- src/tir/transforms/inject_double_buffer.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index d39538c0faf0..c99d23ac8e50 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -103,8 +103,6 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AllocateNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); const VarNode* buf = op->buffer_var.as(); auto it = dbuffer_info_.find(buf); @@ -115,6 +113,8 @@ class DoubleBufferInjector : public StmtExprMutator { << "Has StorageFlatten (TE-based schedules) or " << "FlattenBuffer (TIR-based schedules) been run?"; it->second.stride = op->extents[0]; + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); Array new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)}; ICHECK(it->second.loop != nullptr); @@ -123,7 +123,7 @@ class DoubleBufferInjector : public StmtExprMutator { Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); return op->body; } else { - return stmt; + return StmtExprMutator::VisitStmt_(op); } } From 790708378e272b25e099f830fa8daeafd2ae033b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 3 Feb 2022 14:54:52 -0600 Subject: [PATCH 062/177] fixup! Replacing all T.store TIR calls. --- .../test_tir_transform_storage_flatten.py | 47 +++++++++---------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index a51e926155d3..8e430b035606 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -78,28 +78,25 @@ def test_flatten_storage_align(): def test_flatten_double_buffer(): - dtype = "int64" - n = 100 - m = 4 - tx = te.thread_axis("threadIdx.x") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - C = ib.pointer("float32", name="C") - ib.scope_attr(tx, "thread_extent", 1) - with ib.for_range(0, n) as i: - B = ib.allocate("float32", m, name="B", scope="shared") - with ib.new_scope(): - ib.scope_attr(B.asobject(), "double_buffer_scope", 1) - with ib.for_range(0, m) as j: - B[j] = A[i * 4 + j] - with ib.for_range(0, m) as j: - C[j] = B[j] + 1 - - stmt = ib.get() - - mod = tvm.IRModule.from_expr( - tvm.tir.PrimFunc([A, C], stmt).with_attr("from_legacy_te_schedule", True) - ) + @tvm.script.ir_module + class ModFromScript: + @T.prim_func + def main(A_param: T.handle, C_param: T.handle): + A = T.match_buffer(A_param, (400,), "float32", strides=[1]) + C = T.match_buffer(C_param, (4,), "float32", strides=[1]) + T.func_attr({"from_legacy_te_schedule": True}) + threadIdx_x = T.env_thread("threadIdx.x") + T.launch_thread(threadIdx_x, 1) + for i in T.serial(0, 100): + B = T.allocate([4], "float32", scope="shared", strides=[1]) + with T.attr(B.data, "double_buffer_scope", 1): + for j in T.serial(0, 4): + B[j] = A[4 * i + j] + + for j in T.serial(0, 4): + C[j] = B[j] + 1.0 + + mod = ModFromScript with tvm.transform.PassContext(config={"tir.InjectDoubleBuffer": {"split_loop": 2}}): mod = tvm.transform.Sequential( @@ -112,10 +109,10 @@ def test_flatten_double_buffer(): stmt = mod["main"].body assert isinstance(stmt.body, tvm.tir.Allocate) - assert stmt.body.extents[0].value == 2 + assert list(stmt.body.extents) == [8] - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db")) - f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] + mod = tvm.tir.transform.ThreadSync("shared")(mod) + f = mod["main"] count = [0] From 293f54a45ab29b503a0956fc0228be7d8fdb34f5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 3 Feb 2022 16:15:46 -0600 Subject: [PATCH 063/177] fixup! Updated ethos-u C++ unit tests to remove use of Load/Store. --- python/tvm/relay/backend/contrib/ethosu/tir/dma.py | 6 +++--- python/tvm/relay/backend/contrib/ethosu/tir/identity.py | 4 ++-- python/tvm/relay/backend/contrib/ethosu/tir/passes.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py index 34ea9ef87c96..aa4c09f24d7c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -94,10 +94,10 @@ def get_upscale_params(stmt): _, body = get_op_attrs(stmt) _, _, _, _, _, inner = get_outer_loops(body, "NHWC") if isinstance(inner.value, tvm.tir.Call): - input_pointer = inner.value.args[1].buffer_var + input_pointer = inner.value.args[1].buffer.data else: - input_pointer = inner.value.buffer_var - output_pointer = inner.buffer_var + input_pointer = inner.value.buffer.data + output_pointer = inner.buffer.data return (input_pointer, output_pointer) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py index aacff55c451b..40686ac2336f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py @@ -65,7 +65,7 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur stride_vars = [l.loop_var for l in loops] strides = get_strides(fm_inner.indices[0], stride_vars) - base_address = [get_base_address(index) for index in fm_inner.index] + base_address = [get_base_address(index) for index in fm_inner.indices] data_type = inner.buffer.data.type_annotation.element_type.dtype serial_feature_map = SerialFeatureMap( @@ -76,7 +76,7 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur tile_height_0=loops[0].extent, tile_height_1=0, tile_width_0=loops[1].extent if len(loops) > 1 else 1, - tile_address_0=tvm.tir.BufferLoad(fm_inner, base_address), + tile_address_0=tvm.tir.BufferLoad(fm_inner.buffer, base_address), tile_address_1=0, tile_address_2=0, tile_address_3=0, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 62a2e01f37e8..6f977850abc2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -247,7 +247,7 @@ def _visit(stmt): new_consts.append(const[offset : offset + length]) new_buffer = tvm.tir.decl_buffer((length,), arg.dtype) new_buffers.append(new_buffer) - new_args.append(tvm.tir.expr.BufferLoad(new_buffer.data, [0])) + new_args.append(tvm.tir.expr.BufferLoad(new_buffer, [0])) continue keep_buffers.add(arg.buffer.data) @@ -735,7 +735,7 @@ def _ftransform(f, mod, ctx): if i not in const_dict.keys(): new_params.append(f.params[i]) new_buffer_map[f.params[i]] = f.buffer_map[f.params[i]] - return tvm.tir.PrimFunc(new_params, f.body, f.ret_type, new_buffer_map, f.attrs, f.span) + return tvm.tir.PrimFunc(new_params, f.body, f.ret_type, new_buffer_map, f.preflattened_buffer_map, f.attrs, f.span) def _create_primfunc_without_constants(mod): transform_func = tvm.tir.transform.prim_func_pass( From e38bec0838f5a9851b300c09467d83e94bd85147 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 08:38:37 -0600 Subject: [PATCH 064/177] fixup! Replacing Store/Load in lowering/legalization passes. Fix linting for inject_double_buffer.cc --- src/tir/transforms/inject_double_buffer.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index c99d23ac8e50..03f2ccd40dd1 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -103,7 +103,6 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AllocateNode* op) final { - const VarNode* buf = op->buffer_var.as(); auto it = dbuffer_info_.find(buf); if (it != dbuffer_info_.end()) { From ed21da0bbb4d5d23581d920850134a09f290eb87 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 08:40:19 -0600 Subject: [PATCH 065/177] fixup! Updated ethos-u C++ unit tests to remove use of Load/Store. --- python/tvm/relay/backend/contrib/ethosu/tir/passes.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 6f977850abc2..a4c873d19d75 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -735,7 +735,15 @@ def _ftransform(f, mod, ctx): if i not in const_dict.keys(): new_params.append(f.params[i]) new_buffer_map[f.params[i]] = f.buffer_map[f.params[i]] - return tvm.tir.PrimFunc(new_params, f.body, f.ret_type, new_buffer_map, f.preflattened_buffer_map, f.attrs, f.span) + return tvm.tir.PrimFunc( + new_params, + f.body, + f.ret_type, + new_buffer_map, + f.preflattened_buffer_map, + f.attrs, + f.span, + ) def _create_primfunc_without_constants(mod): transform_func = tvm.tir.transform.prim_func_pass( From 1f593d402ef4096c333c5163a7b12c38ba94263a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 08:42:06 -0600 Subject: [PATCH 066/177] fixup! Added .astype to tvm.script.tir.node.BufferSlice --- python/tvm/script/tir/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py index 42c8754ded95..9cdd0f910cb6 100644 --- a/python/tvm/script/tir/node.py +++ b/python/tvm/script/tir/node.py @@ -155,4 +155,4 @@ def asobject(self) -> BufferLoad: return BufferLoad(self.buffer, indices, span=self.span) def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr: - return self.asobject().astype(dtype) + return self.asobject().astype(dtype, span) From cf5555aa018d5751a1c0e54d2d3d6d48470e551c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 08:53:38 -0600 Subject: [PATCH 067/177] fixup! In test directory, replacing all instances of T.load. --- .../test_ethosu/test_remove_concatenates.py | 24 ++++++++-------- .../test_ethosu/test_replace_conv2d.py | 28 +++++++++---------- .../contrib/test_ethosu/test_scheduler.py | 4 +-- .../contrib/test_ethosu/test_vela_api.py | 16 +++++------ 4 files changed, 36 insertions(+), 36 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index f6e0e2d855cd..d9b08d521be5 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -33,20 +33,20 @@ class ReferenceModule: def main(placeholder: T.Buffer[(1, 8, 12, 16), "int8"], placeholder_1: T.Buffer[(1, 8, 10, 16), "int8"], T_concat: T.Buffer[(1, 8, 32, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") - buffer_4 = T.buffer_var("uint8", "") - buffer_5 = T.buffer_var("uint8", "") - buffer_6 = T.buffer_var("uint8", "") - buffer_7 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") + buffer_4 = T.buffer_decl([], "uint8") + buffer_5 = T.buffer_decl([], "uint8") + buffer_6 = T.buffer_decl([], "uint8") + buffer_7 = T.buffer_decl([], "uint8") # body T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", placeholder_1.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 2992, 12, T.load("uint8", buffer_1, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat.data, 352), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 2992, 12, T.load("uint8", buffer_3, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_4, 0), 2992, 12, T.load("uint8", buffer_5, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_6, 0), 2992, 12, T.load("uint8", buffer_7, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, 12, buffer_1[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat[352], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 2992, 12, buffer_3[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T_concat_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer_4[0], 2992, 12, buffer_5[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T_concat_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T_concat[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_6[0], 2992, 12, buffer_7[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index cc11e21e94e0..d9a31a9b86f3 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -414,16 +414,16 @@ class Conv2dDoubleCascade5: def main(placeholder: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write: T.Buffer[(1, 32, 32, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 160, 12, T.load("uint8", buffer_1, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 304, 12, T.load("uint8", buffer_3, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, T.load("int8", placeholder.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 160, 12, T.load("uint8", buffer_1, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, T.load("int8", ethosu_write.data, 4096), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 304, 12, T.load("uint8", buffer_3, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, 12, buffer_1[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, 12, buffer_3[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, 12, buffer_1[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[4096], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, 12, buffer_3[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) __tvm_meta__ = None @@ -433,14 +433,14 @@ class Conv2dDoubleCascade6: def main(placeholder: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write: T.Buffer[(1, 32, 2, 32, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_var("uint8", "") - buffer_1 = T.buffer_var("uint8", "") - buffer_2 = T.buffer_var("uint8", "") - buffer_3 = T.buffer_var("uint8", "") + buffer = T.buffer_decl([], "uint8") + buffer_1 = T.buffer_decl([], "uint8") + buffer_2 = T.buffer_decl([], "uint8") + buffer_3 = T.buffer_decl([], "uint8") # body ethosu_write_1 = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 12, T.load("uint8", buffer_1, 0), 352, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 35, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 11040, 12, T.load("uint8", buffer_3, 0), 272, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, 12, buffer_1[0], 352, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, buffer_2[0], 11040, 12, buffer_3[0], 272, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index ab2e3942582e..57864218aab6 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -194,11 +194,11 @@ def main(input_buffer: T.Buffer[(1, 56, 56, 96), "int8"], output_buffer: T.Buffe T.evaluate(T.call_extern("ethosu_copy", weight_buffer[0], 2608, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", bias_buffer[0], 240, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, input_buffer.data[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 2608, 12, placeholder_d_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, input_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 2608, 12, placeholder_d_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", weight_buffer2[0], 736, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", bias_buffer2[0], 240, placeholder_d_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 736, 12, placeholder_d_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, output_buffer.data[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", dtype="handle")) + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, output_buffer[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index be75ff9b827a..5e4aaad304a8 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -72,7 +72,7 @@ def main( 8, 0, 8, - placeholder_3.data[0], + placeholder_3[0], 0, 0, 0, @@ -89,7 +89,7 @@ def main( 8, 0, 8, - ethosu_conv2d_1.data[0], + ethosu_conv2d_1[0], 0, 0, 0, @@ -105,10 +105,10 @@ def main( 1, 1, 1, - placeholder_4.data[0], + placeholder_4[0], 0, 12, - placeholder_5.data[0], + placeholder_5[0], 0, 0, 0, @@ -168,7 +168,7 @@ def main( 8, 0, 8, - placeholder_3.data[0], + placeholder_3[0], 0, 0, 0, @@ -185,7 +185,7 @@ def main( 8, 0, 8, - ethosu_conv2d_1.data[0], + ethosu_conv2d_1[0], 0, 0, 0, @@ -201,10 +201,10 @@ def main( 1, 1, 1, - placeholder_4.data[0], + placeholder_4[0], 0, 12, - placeholder_5.data[0], + placeholder_5[0], 0, 0, 0, From 2119f0be9328ffd09e2074a3aa5881953a4990a8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 08:57:49 -0600 Subject: [PATCH 068/177] fixup! Replacing all T.store TIR calls. --- tests/python/unittest/test_tir_transform_simplify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index f298288fee9e..824bef4f32f9 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -30,7 +30,7 @@ def test_stmt_simplify(): body = tvm.tir.LetStmt(n, 10, ib.get()) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body)) body = tvm.tir.transform.Simplify()(mod)["main"].body - assert isinstance(body.body, tvm.tir.Store) + assert isinstance(body.body, tvm.tir.BufferStore) def test_thread_extent_simplify(): @@ -48,7 +48,7 @@ def test_thread_extent_simplify(): body = tvm.tir.LetStmt(n, 10, ib.get()) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C, n], body)) body = tvm.tir.transform.Simplify()(mod)["main"].body - assert isinstance(body.body.body.body, tvm.tir.Store) + assert isinstance(body.body.body.body, tvm.tir.BufferStore) def test_if_likely(): From 4877be3a20796685c8de3ef24c571ca5ee7bae98 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 09:02:35 -0600 Subject: [PATCH 069/177] fixup! Replacing all T.store TIR calls. --- tests/python/unittest/test_tir_transform_remove_no_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py index 8b7a16952af9..e80d46193507 100644 --- a/tests/python/unittest/test_tir_transform_remove_no_op.py +++ b/tests/python/unittest/test_tir_transform_remove_no_op.py @@ -54,7 +54,7 @@ def test_remove_no_op(): ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body assert isinstance(ret, tvm.tir.Evaluate) - store = tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1) + store = tvm.tir.BufferStore(Ab, tvm.tir.BufferLoad(Ab, [i]) + 1, [i + 1]) stmt2 = tvm.tir.SeqStmt([nop(), tvm.tir.SeqStmt([store, nop()])]) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt2)) From a1e8ed41376294b68fb7e62d74bf18d17499b53e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 09:05:40 -0600 Subject: [PATCH 070/177] fixup! In test directory, replacing all instances of T.load. --- tests/python/unittest/test_tir_transform_narrow_datatype.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index 9b95266d3287..c1d659ca1c0f 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -239,7 +239,7 @@ def check(shape, index, target_bits, target_dtype): func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) - assert stmt.value.index.dtype == target_dtype + assert stmt.value.indices[0].dtype == target_dtype check( (const(2 ** 16, "int64"), const(2 ** 15 + 1, "int64")), From 2f461522a8dddfa98d1a366e108a820cb8f162b7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 09:50:33 -0600 Subject: [PATCH 071/177] fixup! Replacing all T.store TIR calls. --- tests/python/unittest/test_tir_transform_lower_tvm_builtin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 63772dea65d7..76d6bb82cce3 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -157,8 +157,8 @@ def build_tir(): Aptr[0] = packed_echo(tvm.tir.const(expected_value[0], "float32")) # return handle # let Aptr_var = testing.echo(Aptr) in Aptr_var[1] = expected_value[1] - Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject())) - ib.emit(tvm.tir.Store(Aptr, tvm.tir.const(expected_value[1], "float32"), 1)) + Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject().data)) + ib.emit(tvm.tir.BufferStore(Aptr, tvm.tir.const(expected_value[1], "float32"), [1])) stmt = ib.get() return tvm.IRModule.from_expr( From e6579b2b0d09a5bbc77453e77de428218abc011f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 09:51:08 -0600 Subject: [PATCH 072/177] fixup! Replacing Store/Load in lowering/legalization passes. --- src/tir/transforms/lower_tvm_builtin.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 00971f7c3a98..d8c7afe92255 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -181,6 +181,10 @@ class BuiltinLower : public StmtExprMutator { stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", max_sizes.shape_stack), stmt); } + if (max_sizes.array_stack != 0) { + stmt = LetStmt(scope.stack_array, StackAlloca("array", max_sizes.array_stack), stmt); + } + if (max_sizes.arg_stack != 0) { scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), max_sizes.arg_stack)}, DataType::Int(32), "stack_tcode"); @@ -189,10 +193,6 @@ class BuiltinLower : public StmtExprMutator { stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", max_sizes.arg_stack), stmt); } - if (max_sizes.array_stack != 0) { - stmt = LetStmt(scope.stack_array, StackAlloca("array", max_sizes.array_stack), stmt); - } - // Copy these values from the earlier search, for use in bounds // checks. scope.max_shape_stack = max_sizes.shape_stack; From 5601423e5d777bdfa9eebb479c2203b32d447af0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 09:59:19 -0600 Subject: [PATCH 073/177] [UnitTests] Added T.preflattened_buffer in expected result --- .../unittest/test_tir_transform_loop_partition.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 53011975ca21..dec847aec1b6 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -539,11 +539,13 @@ def test_simple_rfactor(): @T.prim_func -def partitioned_concat(a: T.handle, b: T.handle, c: T.handle) -> None: +def partitioned_concat( + A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"], C: T.Buffer[(32,), "float32"] +) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [16], dtype="float32") - B = T.match_buffer(b, [16], dtype="float32") - C = T.match_buffer(c, [32], dtype="float32") + T.preflattened_buffer(A, [16], data=A.data) + T.preflattened_buffer(B, [16], data=B.data) + T.preflattened_buffer(C, [32], data=C.data) for i in T.serial(0, 16): C[i] = A[i] for i in T.serial(0, 16): From 40eaef41a05bfe3e779aca1fcbcb786c75aa21bf Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 10:07:42 -0600 Subject: [PATCH 074/177] fixup! In test directory, replacing all instances of T.load. --- tests/python/unittest/test_tir_transform_ir_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_ir_utils.py b/tests/python/unittest/test_tir_transform_ir_utils.py index b6752ee3efd3..8030b77f9946 100644 --- a/tests/python/unittest/test_tir_transform_ir_utils.py +++ b/tests/python/unittest/test_tir_transform_ir_utils.py @@ -16,15 +16,18 @@ # under the License. import pytest import tvm -from tvm import tir +from tvm import tir, ir def test_convert_ssa(): + dtype = "int32" zero = tir.const(0) nop = tir.Evaluate(zero) - v = tir.Var("i1", "int32") + var_type = ir.PointerType(ir.PrimType(dtype)) + v = tir.Var("i1", var_type) + buf = tir.decl_buffer([16], dtype=dtype, data=v) for_stmt = tir.For(v, zero, zero, tir.ForKind.SERIAL, nop) - load = tir.Evaluate(tir.Load("int32", v, zero)) + load = tir.Evaluate(tir.BufferLoad(buf, [zero])) seq = tir.SeqStmt([for_stmt, for_stmt, load]) func = tir.PrimFunc([], seq) mod = tvm.IRModule({"main": func}) From 5cad4827693b28c1f03953ce717ee8bcdb6f8164 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 11:15:04 -0600 Subject: [PATCH 075/177] [UnitTests] Bound checker update, compare against N-d buffer bounds. --- .../unittest/test_tir_transform_instrument_bound_checkers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py index 2c9997f6fe78..ba83ef90f616 100644 --- a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py +++ b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py @@ -161,7 +161,7 @@ def check_attr_stmt(x): if ( isinstance(x, tvm.tir.AttrStmt) and x.attr_key == "buffer_bound" - and str(x.value) == str(n) + and tvm.ir.structural_equal(x.value.args, [n]) ): return True return False From 4775f897469000799c8af3439aa70148abe51877 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 13:28:44 -0600 Subject: [PATCH 076/177] Fixup, bound checker vectorize test. --- .../test_tir_transform_instrument_bound_checkers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py index ba83ef90f616..279faa54d830 100644 --- a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py +++ b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py @@ -94,7 +94,7 @@ def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b): @tvm.testing.requires_llvm def test_in_bounds_vectorize_llvm(): n = 512 - lanes = 2 + lanes = 1 A = te.placeholder((n,), name="A", dtype="float32x%d" % lanes) B = te.compute((n,), lambda i: A[i], name="B") C = te.compute((n,), lambda i: B[i] + tvm.tir.const(1, A.dtype), name="C") @@ -111,7 +111,9 @@ def test_in_bounds_vectorize_llvm(): f = tvm.build(s, [A, C], "llvm") dev = tvm.cpu(0) # launch the kernel. - a = tvm.nd.empty((n,), A.dtype).copyfrom(np.random.uniform(size=(n, lanes))) + a = tvm.nd.empty((n,), A.dtype).copyfrom( + np.random.uniform(size=[n] + ([] if lanes == 1 else [lanes])) + ) c = tvm.nd.empty((n,), C.dtype, dev) f(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) From f99041550ae09d1b1e9fa6826f68e5cba3463b64 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 13:54:23 -0600 Subject: [PATCH 077/177] fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate --- ...est_tir_transform_inject_virtual_thread.py | 41 +++++++++++-------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 673267a9b1fa..1d13acce369a 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -17,8 +17,10 @@ import tvm from tvm import te +vthread_name = tvm.testing.parameter("vthread", "cthread") -def test_vthread(): + +def test_vthread(vthread_name): dtype = "int64" n = 100 m = 4 @@ -35,7 +37,7 @@ def get_vthread(name): ib.scope_attr(ty, "virtual_thread", nthread) B = ib.allocate("float32", m, name="B", scope="shared") B[i] = A[i * nthread + tx] - bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) + bbuffer = B.asobject() ib.emit( tvm.tir.call_extern( "int32", @@ -47,20 +49,19 @@ def get_vthread(name): C[i * nthread + tx] = B[i] + 1 return ib.get() - stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("vthread"))) - )["main"] - - assert stmt.body.body.extents[0].value == 2 + if vthread_name == "vthread": + B_expected_alloc = m * nthread + elif vthread_name == "cthread": + B_expected_alloc = m * nthread * nthread stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) + tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread(vthread_name))) )["main"] - assert len(stmt.body.body.extents) == 3 + assert list(stmt.body.body.extents) == [B_expected_alloc] -def test_vthread_extern(): +def test_vthread_extern(vthread_name): dtype = "int64" n = 100 m = 4 @@ -76,9 +77,9 @@ def get_vthread(name): A = ib.allocate("float32", m, name="A", scope="shared") B = ib.allocate("float32", m, name="B", scope="shared") C = ib.allocate("float32", m, name="C", scope="shared") - cbuffer = tvm.tir.decl_buffer((m,), dtype=C.dtype, data=C.asobject()) - abuffer = tvm.tir.decl_buffer((m,), dtype=A.dtype, data=A.asobject()) - bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) + abuffer = A.asobject() + bbuffer = B.asobject() + cbuffer = C.asobject() A[tx] = tx + 1.0 B[ty] = ty + 1.0 ib.emit( @@ -92,13 +93,19 @@ def get_vthread(name): ) return ib.get() + if vthread_name == "vthread": + A_expected_alloc = m * nthread + elif vthread_name == "cthread": + A_expected_alloc = m * nthread * nthread + + C_expected_alloc = m * nthread * nthread + stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) + tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread(vthread_name))) )["main"] - assert stmt.body.body.extents[0].value == 2 - assert stmt.body.body.body.body.extents[0].value == 2 - assert len(stmt.body.body.body.body.extents) == 3 + assert list(stmt.body.body.extents) == [A_expected_alloc] + assert list(stmt.body.body.body.body.extents) == [C_expected_alloc] def test_vthread_if_then_else(): From b26a2cda1b20e94b849d725583c06f786ad4247d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 14:53:31 -0600 Subject: [PATCH 078/177] [UnitTest] Fixed breakage in InjectRollingBuffer test. Needed a bit more re-writing than usual, because the test was explicitly calling lowering passes, then calling `tvm.build`. Fixed by using the standard lowering flow, with preprocessing steps inserting with `tir.add_lower_pass`. --- ...est_tir_transform_inject_rolling_buffer.py | 62 ++++++++++--------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py index 2298fe94da18..4f70639eada9 100644 --- a/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_rolling_buffer.py @@ -37,33 +37,41 @@ def _tile_nd(s, tensor, tile): return outer_indices, inner_indices -def _lower_schedule(sch, args): - sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) +@tvm.tir.transform.prim_func_pass(opt_level=0) +def remove_rolling_buffer_attr(func, mod, ctx): + def unwrap(node): + if isinstance(node, tvm.tir.AttrStmt) and node.attr_key == "rolling_buffer_scope": + return node.body + else: + return node + + return func.with_body( + tvm.tir.stmt_functor.ir_transform( + func.body, None, postorder=unwrap, only_enable=["tir.AttrStmt"] + ) + ) - compact = tvm.te.schedule.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, None) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) - func = func.with_attr("global_symbol", "main") - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({"main": func}) - return mod +@tvm.tir.transform.prim_func_pass(opt_level=0) +def verify_no_rolling_buffer_attr(func, mod, ctx): + def verify(node): + if isinstance(node, tvm.tir.AttrStmt): + assert node.attr_key != "rolling_buffer_scope", "Failed to lower rolling buffers" + tvm.tir.stmt_functor.post_order_visit(func.body, verify) -def _verify_schedule(sch, inputs, output): - mod = _lower_schedule(sch, inputs + [output]) - mods = [] - mods.append(mod) - mod = tvm.tir.transform.InjectRollingBuffer()(mod) + return func - def _check(stmt): - if isinstance(stmt, tvm.tir.AttrStmt): - assert stmt.attr_key != "rolling_buffer_scope", "Failed to lower rolling buffers" - tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _check) - mods.append(mod) +def _verify_schedule(sch, inputs, output): + user_pass_lists = [ + [(0, remove_rolling_buffer_attr), (0, verify_no_rolling_buffer_attr)], + [(0, tvm.tir.transform.InjectRollingBuffer()), (0, verify_no_rolling_buffer_attr)], + ] + built_funcs = [] + for user_pass_list in user_pass_lists: + with tvm.transform.PassContext(config={"tir.add_lower_pass": user_pass_list}): + built_funcs.append(tvm.build(sch, inputs + [output])) outputs = [] ctx = tvm.cpu(0) @@ -75,15 +83,9 @@ def _check(stmt): ) shape = [i.value for i in output.shape] out = tvm.nd.array(np.zeros(shape, dtype="int8"), ctx) - for mod in mods: - mod = tvm.tir.transform.StorageFlatten(64)(mod) - mod = tvm.tir.transform.NarrowDataType(32)(mod) - mod = tvm.tir.transform.LoopPartition()(mod) - mod = tvm.tir.transform.StorageRewrite()(mod) - # Build for CPU execution - f = tvm.build(mod) - f(*input_data, out) - outputs.append(out.asnumpy()) + for func in built_funcs: + func(*input_data, out) + outputs.append(out.numpy()) np.testing.assert_equal(outputs[0], outputs[1]) From f400062b2110c295a272bcfb78579ef6dc31d73c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 14:57:48 -0600 Subject: [PATCH 079/177] fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate --- .../unittest/test_tir_transform_inject_double_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index 9b37bcaaacbc..0f4cc00f0702 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -30,7 +30,7 @@ def test_double_buffer(): with ib.for_range(0, n) as i: B = ib.allocate("float32", m, name="B", scope="shared") with ib.new_scope(): - ib.scope_attr(B.asobject(), "double_buffer_scope", 1) + ib.scope_attr(B.asobject().data, "double_buffer_scope", 1) with ib.for_range(0, m) as j: B[j] = A[i * 4 + j] with ib.for_range(0, m) as j: @@ -48,7 +48,7 @@ def test_double_buffer(): stmt = mod["db"].body assert isinstance(stmt.body, tvm.tir.Allocate) - assert stmt.body.extents[0].value == 2 + assert list(stmt.body.extents) == [m * 2] f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] From 01d699f5725a7a3339453a13f11b1d88dcc47148 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 15:21:11 -0600 Subject: [PATCH 080/177] [UnitTest] Fixed breakage in flatten buffer unit tests. - Updated pass to allow BufferStore/BufferLoad nodes to be visited before the block's alloc buffer. - Added `T.preflattened_buffer` annotations. --- src/tir/transforms/flatten_buffer.cc | 13 ++++---- .../test_tir_transform_flatten_buffer.py | 30 ++++++++++++++----- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 0c15f4af2fa2..5977601db058 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -66,7 +66,7 @@ class BufferFlattener : public StmtExprMutator { private: explicit BufferFlattener(const Map& extern_buffer_map) { for (const auto& kv : extern_buffer_map) { - updated_extern_buffer_map_.Set(kv.first, MakeFlattenedBuffer(kv.second)); + updated_extern_buffer_map_.Set(kv.first, GetFlattenedBuffer(kv.second)); } } @@ -84,7 +84,7 @@ class BufferFlattener : public StmtExprMutator { } // Step 3. Handle allocations in reverse order for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { - Buffer buffer = MakeFlattenedBuffer(new_block->alloc_buffers[i - 1]); + Buffer buffer = GetFlattenedBuffer(new_block->alloc_buffers[i - 1]); body = Allocate(buffer->data, buffer->dtype, buffer->shape, const_true(), std::move(body)); } return body; @@ -143,8 +143,11 @@ class BufferFlattener : public StmtExprMutator { } } - Buffer MakeFlattenedBuffer(Buffer buf) { - ICHECK_EQ(buffer_remap_.count(buf), 0) << "Multiple definitions of " << buf; + Buffer GetFlattenedBuffer(Buffer buf) { + auto it = buffer_remap_.find(buf); + if (it != buffer_remap_.end()) { + return it->second; + } auto flattened = buf.GetFlattenedBuffer(); @@ -196,7 +199,7 @@ class BufferFlattener : public StmtExprMutator { template Node VisitBufferAccess(Node node) { auto flattened_indices = node->buffer->ElemOffset(node->indices); - Buffer flattened_buffer = buffer_remap_.at(node->buffer); + Buffer flattened_buffer = GetFlattenedBuffer(node->buffer); auto writer = node.CopyOnWrite(); writer->buffer = flattened_buffer; diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 7dab0589dd9e..68b1ad338964 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -50,8 +50,10 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: @T.prim_func def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") + A = T.match_buffer(a, 256, "float32") + C = T.match_buffer(c, 256, "float32") + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) for i in T.serial(0, 16): B_new = T.allocate([16], "float32", "global") for j in T.serial(0, 16): @@ -85,8 +87,10 @@ def compacted_gpu_func(a: T.handle, c: T.handle) -> None: @T.prim_func def flattened_gpu_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") + A = T.match_buffer(a, 256, "float32") + C = T.match_buffer(c, 256, "float32") + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) i0 = T.env_thread("blockIdx.x") i1 = T.env_thread("threadIdx.x") @@ -126,8 +130,10 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> @T.prim_func def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: - A = T.match_buffer(a, (n, m), "float32") - C = T.match_buffer(c, (n, m), "float32") + A = T.match_buffer(a, n * m, "float32") + C = T.match_buffer(c, n * m, "float32") + T.preflattened_buffer(A, (n, m), "float32", data=A.data) + T.preflattened_buffer(C, (n, m), "float32", data=C.data) for i in range(0, n): B = T.allocate([m], "float32", "global") @@ -154,6 +160,8 @@ def compacted_predicate_func(a: T.handle, c: T.handle) -> None: def flattened_predicate_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (32), "float32") C = T.match_buffer(c, (32), "float32") + T.preflattened_buffer(A, (32), "float32", data=A.data) + T.preflattened_buffer(C, (32), "float32", data=C.data) for i, j in T.grid(5, 7): if i * 7 + j < 32: @@ -176,6 +184,8 @@ def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None: def flattened_unit_loop_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (32), "float32") C = T.match_buffer(c, (32), "float32") + T.preflattened_buffer(A, (32), "float32", data=A.data) + T.preflattened_buffer(C, (32), "float32", data=C.data) for x, z in T.grid(4, 8): C[x * 8 + z] = A[x * 8 + z] + 1.0 @@ -201,6 +211,8 @@ def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None: def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: A = T.match_buffer(a, (32), "float32") D = T.match_buffer(d, (32), "float32") + T.preflattened_buffer(A, (32), "float32", data=A.data) + T.preflattened_buffer(D, (32), "float32", data=D.data) for i in range(0, 32): B = T.allocate((32,), "float32", "global") @@ -235,8 +247,10 @@ def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None: @T.prim_func def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") + A = T.match_buffer(a, (256,), "float32") + C = T.match_buffer(c, (256,), "float32") + T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) + T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) for i0 in T.serial(0, 4): B_new = T.allocate([68], "float32", "global") for i1 in T.serial(0, 4): From 2f9397050013d5bbafd9333b28c80728870d2385 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 15:24:08 -0600 Subject: [PATCH 081/177] fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate --- .../unittest/test_tir_transform_combine_context_call.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py index 191aec4b4641..3271e6e2569a 100644 --- a/tests/python/unittest/test_tir_transform_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -29,10 +29,10 @@ def device_context(dev_id): n = te.var("n") A = ib.allocate("float32", n, name="A", scope="global") with ib.for_range(0, n, name="i") as i: - ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A)) + ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A.asobject().data)) with ib.for_range(0, 10, name="j") as j: - ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(1), A)) - ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A)) + ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(1), A.asobject().data)) + ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A.asobject().data)) body = ib.get() mod = tvm.IRModule({"func": tvm.tir.PrimFunc([dev_type, n], body)}) From 17b963c5b85d6d21c55ef898aa5446341a8eb36b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 15:38:18 -0600 Subject: [PATCH 082/177] [UnitTests] Fixed breakage in test_tir_buffer.py - Updated vload test for new behavior. - Added test for offset_of, testing behavior no longer in vload. - Added null check for buffer visitor. --- src/tir/ir/buffer.cc | 2 +- src/tir/transforms/flatten_buffer.cc | 1 + tests/python/unittest/test_tir_buffer.py | 10 +++++++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 1e470f68ada1..684c56ddeeb7 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -308,7 +308,7 @@ Array BufferNode::ElemOffset(Array input_indices) const { } } - return output_indices; + return SimplifyArray(&ana, output_indices); } inline Array BufferOffset(const BufferNode* n, Array index, DataType dtype) { diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 5977601db058..c7cc51d27113 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -198,6 +198,7 @@ class BufferFlattener : public StmtExprMutator { template Node VisitBufferAccess(Node node) { + ICHECK(node->buffer.defined()); auto flattened_indices = node->buffer->ElemOffset(node->indices); Buffer flattened_buffer = GetFlattenedBuffer(node->buffer); diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 5eac95ad77aa..e790ffc199e5 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -82,7 +82,15 @@ def test_buffer_vload(): n = te.size_var("n") Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) load = Ab.vload([2, 3]) - tvm.testing.assert_prim_expr_equal(load.index, n * 2 + 103) + tvm.ir.assert_structural_equal(load.indices, [2, 3]) + + +def test_buffer_offset_of(): + m = te.size_var("m") + n = te.size_var("n") + Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) + offset = Ab.offset_of([2, 3]) + tvm.ir.assert_structural_equal(offset, [n * 2 + 103]) def test_buffer_vload_nullptr(): From 28f63394ae49f00f04e229eb75ec44775e1f8f94 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 16:26:11 -0600 Subject: [PATCH 083/177] fixup! Replacing Load/Store in codegens. --- src/tir/transforms/make_packed_api.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 368c84a98e2e..72dc20e1c03b 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -99,8 +99,8 @@ class ReturnRewriter : public StmtMutator { if (it != dummy_val_buffer_map_.end()) { info.dummy_val_buffer = it->second; } else { - info.dummy_val_buffer = - Buffer(ret_var_, dtype, {1}, {1}, ConstInt32(0), ret_var_->name_hint, 0, 0, kDefault); + info.dummy_val_buffer = Buffer(ret_var_, info.expr.dtype(), {1}, {1}, ConstInt32(0), + ret_var_->name_hint, 0, 0, kDefault); dummy_val_buffer_map_[info.tcode] = info.dummy_val_buffer; } @@ -116,7 +116,7 @@ class ReturnRewriter : public StmtMutator { Stmt WriteToOut(PrimExpr val) { auto info = ConvertForFFI(val); - Stmt store_val = BufferStore(info.dummy_val_buffer, val, {0}); + Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0}); Stmt store_tcode = BufferStore(info.dummy_tcode_buffer, info.tcode, {0}); Stmt ret_zero = Evaluate(tvm::ret(0)); return SeqStmt({store_val, store_tcode, ret_zero}); @@ -166,8 +166,8 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { Buffer buf_packed_arg_type_ids = decl_buffer({IntImm(DataType::Int(32), func_ptr->params.size())}, DataType::Int(32), "arg_type_ids"); Var v_num_packed_args("num_args", DataType::Int(32)); - Var v_out_ret_value("out_ret_value", DataType::Handle()); - Var v_out_ret_tcode("out_ret_tcode", DataType::Handle()); + Var v_out_ret_value("out_ret_value", PointerType(PrimType(DataType::UInt(8)))); + Var v_out_ret_tcode("out_ret_tcode", PointerType(PrimType(DataType::Int(32)))); Var v_resource_handle("resource_handle", DataType::Handle()); // The arguments of the function. Array args; From 4d6496ffa29620f6d79950bb1bc9057853f5150f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Feb 2022 16:36:01 -0600 Subject: [PATCH 084/177] [UnitTest] ComputeInline, opaque access test updates --- .../test_tir_schedule_compute_inline.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index 3e31e9dd0d73..f8d767da4645 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -183,7 +183,12 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) - C[vi, vj] = B[vi * 128 + vj] + 1.0 + T.evaluate( + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle" + ) + ) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -200,8 +205,17 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) - C[vi * 128 + vj] = B[vi, vj] + 1.0 - C[vi, vj] = B[vi * 16 + vj] + 1.0 + T.evaluate( + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle" + ) + ) + T.evaluate( + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), C.data, 0, 128, "w", dtype="handle" + ) + ) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func From 6f879d044a4281b6a9d32695f92e1ea783e023c7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 7 Feb 2022 14:41:20 -0600 Subject: [PATCH 085/177] [UnitTest] Fixup, allow unit test to use `ib.pointer()[0]`. --- python/tvm/tir/ir_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index f95b6b73fadb..092adf6901c6 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -448,7 +448,7 @@ def pointer(self, content_type, name="ptr", scope=""): ptr : BufferVar The buffer var representing the buffer. """ - buffer = _buffer.decl_buffer(shape=[], dtype=content_type, name=name, scope=scope) + buffer = _buffer.decl_buffer(shape=[1], dtype=content_type, name=name, scope=scope) return BufferVar(self, buffer, content_type) def buffer_ptr(self, buf): From 4de03d4c7757a33ebac939fc870213aee8fce324 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 7 Feb 2022 14:42:19 -0600 Subject: [PATCH 086/177] fixup! Replacing Load/Store in codegens. The updated CodegenLLVM should use the BufferStore/BufferLoad convention of indexing by `sizeof(dtype)`, rather than `sizeof(dtype.element_of())`. --- src/target/llvm/codegen_llvm.cc | 34 ++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 44cf72153dc0..bece4e4d5ffd 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1244,14 +1244,15 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; DataType t = op->dtype; + DataType buffer_element_dtype = op->buffer->dtype; Var buffer_var = op->buffer->data; - const PrimExpr& buffer_index = op->indices[0]; + PrimExpr buffer_index = op->indices[0]; bool is_volatile = volatile_buf_.count(buffer_var.get()); llvm::Value* buffer = MakeValue(buffer_var); llvm::Value* index = MakeValue(buffer_index); - if (t.lanes() == 1) { + if (t.lanes() == buffer_element_dtype.lanes()) { int alignment, native_bits; GetAlignment(t, buffer_var.get(), buffer_index, &alignment, &native_bits); TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); @@ -1272,9 +1273,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, buffer_var.get(), ramp->base, &alignment, &native_bits); - ICHECK_EQ(ramp->lanes, t.lanes()); + ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), t.lanes()); // The index argument is element-based, to create buffer pointer for t's element type. - TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + TypedPointer buffer_ptr = + CreateBufferPtr(buffer_element_dtype, buffer, MakeValue(ramp->base)); unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); buffer_ptr.type = DTypeToLLVMType(t); @@ -1382,7 +1384,8 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; - DataType t = op->value.dtype(); + DataType value_dtype = op->value.dtype(); + DataType buffer_element_dtype = op->buffer->dtype; Var buffer_var = op->buffer->data; PrimExpr buffer_index = op->indices[0]; @@ -1391,10 +1394,10 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { llvm::Value* index = MakeValue(buffer_index); llvm::Value* value = MakeValue(op->value); - if (t.lanes() == 1) { + if (value_dtype.lanes() == buffer_element_dtype.lanes()) { int alignment, native_bits; - GetAlignment(t, buffer_var.get(), buffer_index, &alignment, &native_bits); - TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); + GetAlignment(value_dtype, buffer_var.get(), buffer_index, &alignment, &native_bits); + TypedPointer buffer_ptr = CreateBufferPtr(value_dtype, buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, llvm::Align(alignment), is_volatile); @@ -1409,13 +1412,14 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { if (const RampNode* ramp = buffer_index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; - GetAlignment(t, buffer_var.get(), ramp->base, &alignment, &native_bits); - ICHECK_EQ(ramp->lanes, t.lanes()); + GetAlignment(value_dtype, buffer_var.get(), ramp->base, &alignment, &native_bits); + ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), value_dtype.lanes()); // The index argument is element-based, to create buffer pointer for t's element type. - TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + TypedPointer buffer_ptr = + CreateBufferPtr(buffer_element_dtype, buffer, MakeValue(ramp->base)); unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); - buffer_ptr.type = DTypeToLLVMType(t); + buffer_ptr.type = DTypeToLLVMType(value_dtype); buffer_ptr.addr = builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 @@ -1430,11 +1434,11 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { } } } - ICHECK_GE(t.bits(), 8); + ICHECK_GE(value_dtype.bits(), 8); // scalarized store. - int basic_align = t.bits() / 8; + int basic_align = value_dtype.bits() / 8; auto f = [&](int i, llvm::Value* index) { - TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(value_dtype.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), buffer_ptr.addr, From 418d1aaa9a7a31cffd9807ddce8a94db301008b5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 7 Feb 2022 14:43:26 -0600 Subject: [PATCH 087/177] fixup! Replacing Store/Load in lowering/legalization passes. BF16Legalize should also update the preflattened_buffer_map, since it is overwriting the `BufferNode::data` stored in the buffer_map. --- src/tir/transforms/bf16_legalize.cc | 38 +++++++++++++++++++++++++--- src/tir/transforms/vectorize_loop.cc | 29 +++++++++++++-------- 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index e398f75cb0fe..9271b2e7dfc5 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -299,9 +299,10 @@ class BF16LowerRewriter : public StmtExprMutator { } void AlterBuffers(PrimFuncNode* op) { - std::vector> changes; + Map new_buffer_map; for (auto& itr : op->buffer_map) { + auto param_var = itr.first; auto oldbuf = itr.second; if (oldbuf->dtype.is_bfloat16()) { DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); @@ -311,14 +312,43 @@ class BF16LowerRewriter : public StmtExprMutator { oldbuf->buffer_type); buffer_remap_[oldbuf] = newbuf; var_remap_[oldbuf->data] = buffer_var; - changes.emplace_back(itr.first, newbuf); + new_buffer_map.Set(param_var, newbuf); } else { - changes.emplace_back(itr); + new_buffer_map.Set(param_var, oldbuf); + } + } + + // Most passes do not change the preflattened buffer map, nor + // should they change it. This is an exception, because the Var + // associated with the `BufferNode::data` in + // `PrimFunc::buffer_map` may be replaced, and the corresponding + // Var in the `PrimFunc::preflattened_buffer_map` must also be + // replaced. + Map new_preflattened_buffer_map; + for (auto& itr : op->preflattened_buffer_map) { + auto param_var = itr.first; + auto oldbuf = itr.second; + if (oldbuf->dtype.is_bfloat16()) { + auto it = new_buffer_map.find(param_var); + ICHECK(it != new_buffer_map.end()) + << "PrimFunc parameter " << param_var->name_hint + << " is associated with the pre-flattened buffer " << oldbuf->name + << ", but isn't associated with any post-flatten buffer."; + const Buffer& flatbuf = (*it).second; + DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); + auto newbuf = Buffer(flatbuf->data, dtype, oldbuf->shape, oldbuf->strides, + oldbuf->elem_offset, oldbuf->name, oldbuf->data_alignment, + oldbuf->offset_factor, oldbuf->buffer_type); + buffer_remap_[oldbuf] = newbuf; + new_preflattened_buffer_map.Set(param_var, newbuf); + } else { + new_preflattened_buffer_map.Set(param_var, oldbuf); } } if (buffer_remap_.size() != 0) { - op->buffer_map = Map(changes.begin(), changes.end()); + op->buffer_map = new_buffer_map; + op->preflattened_buffer_map = new_preflattened_buffer_map; } } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 3fed1e193de9..20f67e0e40b0 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -434,26 +434,33 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); if (!indices.same_as(op->indices) || !value.same_as(op->value)) { - int index_lanes = 1; - for (const auto& index : indices) { - index_lanes *= index.dtype().lanes(); + // How many lanes of indexing are present in the index and + // buffer element type, excluding the last index. T + int other_index_lanes = op->buffer->dtype.lanes(); + for (size_t i = 0; i < indices.size() - 1; i++) { + other_index_lanes *= indices[i].dtype().lanes(); } - int lanes = std::max(index_lanes, value.dtype().lanes()); + // The total number of lanes of indexing, including the last index. + int index_lanes = other_index_lanes * indices[indices.size() - 1].dtype().lanes(); - int last_index_lanes = indices[indices.size() - 1].dtype().lanes(); - int earlier_index_lanes = index_lanes / last_index_lanes; + // The total number of lanes in this store operation. Either + // the index or the value will be broadcast out to this number + // of lanes, depending on which has more lanes. + int total_lanes = std::max(index_lanes, value.dtype().lanes()); + + ICHECK_EQ(total_lanes % other_index_lanes, 0) + << "When storing to buffer " << op->buffer->name << ", cannot produce " << total_lanes + << " lanes of storage location by changing the last index."; + int last_index_lanes = total_lanes / other_index_lanes; // Broadcast the last index such that the total number of index // lanes matches the desired number. - ICHECK_EQ(lanes % last_index_lanes, 0) - << "Cannot produce location with " << value.dtype().lanes(); - indices.Set(indices.size() - 1, - BroadcastTo(indices[indices.size() - 1], lanes / earlier_index_lanes)); + indices.Set(indices.size() - 1, BroadcastTo(indices[indices.size() - 1], last_index_lanes)); auto writer = store.CopyOnWrite(); writer->indices = indices; - writer->value = BroadcastTo(value, lanes); + writer->value = BroadcastTo(value, total_lanes); } return std::move(store); From 0e614af3103051c25c2c9b3a6ad9cf62cf5dcdb7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 8 Feb 2022 09:10:00 -0600 Subject: [PATCH 088/177] fixup! Replacing all T.store TIR calls. --- tests/python/unittest/test_target_codegen_llvm.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index c2a7326d517a..45d8b8725c82 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -51,7 +51,7 @@ def test_llvm_void_intrin(): ib = tvm.tir.ir_builder.create() A = ib.pointer("uint8", name="A") # Create an intrinsic that returns void. - x = tvm.tir.call_llvm_intrin("", "llvm.va_start", tvm.tir.const(1, "uint32"), A) + x = tvm.tir.call_llvm_intrin("", "llvm.va_start", tvm.tir.const(1, "uint32"), A.asobject().data) ib.emit(x) body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main")) @@ -672,13 +672,12 @@ def my_vectorize(): def vectorizer(op): store = op.body idx = tvm.tir.Ramp(tvm.tir.const(0, "int32"), tvm.tir.const(1, "int32"), 8) - all_ones = tvm.tir.const(1, "int32x8") value = store.value b_idx = tvm.tir.Shuffle([idx], [tvm.tir.const(i, "int32") for i in range(7, -1, -1)]) - new_a = tvm.tir.Load("int32x8", value.a.buffer_var, idx, all_ones) - new_b = tvm.tir.Load("int32x8", value.b.buffer_var, b_idx, all_ones) + new_a = tvm.tir.BufferLoad(value.a.buffer, [idx]) + new_b = tvm.tir.BufferLoad(value.b.buffer, [b_idx]) value = new_a + new_b - return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) + return tvm.tir.BufferStore(store.buffer, new_a + new_b, [idx]) def _transform(f, *_): return f.with_body( @@ -925,7 +924,7 @@ def threadpool_nested_parallel_loop( T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in T.parallel(4): for j in T.parallel(4): - B.data[i * 4 + j] = A.data[i * 4 + j] * 2.0 + B[i, j] = A[i, j] * 2.0 with pytest.raises(tvm.TVMError) as e: tvm.build({"llvm": tvm.IRModule.from_expr(threadpool_nested_parallel_loop)}) From 3b3e7fbe6e249cb0234c5937a1db6646a1a70d7f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 8 Feb 2022 09:53:10 -0600 Subject: [PATCH 089/177] Fixed failing codegen c host unit tests. - Generated functions were making `uint8_t*` parameter arguments for array handle for return value, rather than the earlier `void*`. - New parameter type was due to using `PointerType(PrimType(DataType::UInt(8)))` as the type annotation, to be usable as `BufferNode::data`. - Changing to `PointerType(PrimType(DataType::Void()))` still allows usage as buffer, more appropriately expresses semantics. - Updated C codegens to allow `void*` types to be generated from variables with type annotation, in addition to the previous behavior of `DataType::Handle()` variables without type annotation. --- src/target/source/codegen_c_host.cc | 4 ++++ src/target/source/codegen_cuda.cc | 6 ++++++ src/target/source/codegen_metal.cc | 5 +++++ src/target/source/codegen_opencl.cc | 4 ++++ src/target/source/codegen_source_base.cc | 4 ++++ src/tir/transforms/make_packed_api.cc | 2 +- 6 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 515cdccb88fb..6f1ec93e2d11 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -142,6 +142,10 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "void*"; return; } + if (t.is_void()) { + os << "void"; + return; + } if (t == DataType::Bool()) { os << "bool"; return; diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 28d972232f5f..0b75167c6df2 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -171,6 +171,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "void*"; return; } + + if (t.is_void()) { + os << "void"; + return; + } + bool fail = false; if (t.is_float()) { switch (t.bits()) { diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index b44afec57d5d..a76da36ea725 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -177,6 +177,11 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "void*"; return; } + + if (t.is_void()) { + os << "void"; + return; + } if (t == DataType::Bool()) { os << "bool"; return; diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 28277077179f..8d0179c183f2 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -174,6 +174,10 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "void*"; return; } + if (t.is_void()) { + os << "void"; + return; + } if (t == DataType::Bool()) { os << "bool"; return; diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 5dcf1587bdb9..5acb42071b62 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -119,6 +119,10 @@ void CodeGenSourceBase::PrintType(DataType type, std::ostream& os) { // NOLINT( os << "void*"; return; } + if (type.is_void()) { + os << "void"; + return; + } if (type.is_float()) { if (type.bits() == 32) { os << "float"; diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 72dc20e1c03b..a31349fe1c07 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -166,7 +166,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { Buffer buf_packed_arg_type_ids = decl_buffer({IntImm(DataType::Int(32), func_ptr->params.size())}, DataType::Int(32), "arg_type_ids"); Var v_num_packed_args("num_args", DataType::Int(32)); - Var v_out_ret_value("out_ret_value", PointerType(PrimType(DataType::UInt(8)))); + Var v_out_ret_value("out_ret_value", PointerType(PrimType(DataType::Void()))); Var v_out_ret_tcode("out_ret_tcode", PointerType(PrimType(DataType::Int(32)))); Var v_resource_handle("resource_handle", DataType::Handle()); // The arguments of the function. From 7c6ded0df954ef5cb4f129381c05f2d6dcdecf0d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 8 Feb 2022 10:40:26 -0600 Subject: [PATCH 090/177] Fixup, StorageFlatten when applied to post-StorageRewrite functions. Identified in a test that applied `tvm.lower`, then `tvm.build` on the result. If the result of an allocate node is used as the backing buffer for multiple buffers, such as the output of the StorageRewrite pass, then StorageFlatten would erroneously think that the second occurrence was an usage without earlier definition. --- src/tir/transforms/storage_flatten.cc | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 859b735f7000..57f02fc2f751 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -534,13 +534,12 @@ class BufferStrideLegalize : public StmtExprMutator { template Node VisitBufferAccess(Node node) { auto alloc_key = node->buffer->data.get(); - if (allocate_node_var_.count(alloc_key)) { + if (!buf_map_.count(node->buffer) && allocate_node_var_.count(alloc_key)) { BufferEntry entry; entry.remap_to = WithStrides(node->buffer); entry.in_scope = true; entry.is_external = false; buf_map_[node->buffer] = entry; - allocate_node_var_.erase(alloc_key); } auto it = buf_map_.find(node->buffer); @@ -1039,11 +1038,10 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& GetBufferEntry(Buffer buffer) { auto alloc_key = buffer->data.get(); - if (allocate_node_var_.count(alloc_key)) { + if (!buf_map_.count(buffer.get()) && allocate_node_var_.count(alloc_key)) { BufferEntry entry; entry.buffer = buffer; buf_map_[buffer.get()] = std::move(entry); - allocate_node_var_.erase(alloc_key); } auto it = buf_map_.find(buffer.get()); @@ -1514,12 +1512,11 @@ class StorageFlattener : public StmtExprMutator { const BufferEntry& GetBufferEntry(Buffer buffer) { auto alloc_key = buffer->data.get(); - if (allocate_node_var_.count(alloc_key)) { + if (!buf_map_.count(buffer) && allocate_node_var_.count(alloc_key)) { BufferEntry entry; entry.buffer = buffer; entry.flattened_buffer = buffer.GetFlattenedBuffer(); buf_map_[buffer] = std::move(entry); - allocate_node_var_.erase(alloc_key); } auto it = buf_map_.find(buffer); From 068b1793f99fbbfa0319ecb2824e86a26389be94 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 8 Feb 2022 12:17:53 -0600 Subject: [PATCH 091/177] fixup, StorageFlatten When flattening a boolean buffer, the backing buffer should have type int8, not the preflattened buffer. --- src/tir/transforms/storage_flatten.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 57f02fc2f751..e0707e53562d 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1348,8 +1348,8 @@ class StorageFlattener : public StmtExprMutator { // dedicated pass. // Boolean tensors are backed by a Int8 array. - if (e.buffer->dtype == DataType::Bool()) { - auto writer = e.buffer.CopyOnWrite(); + if (e.flattened_buffer->dtype == DataType::Bool()) { + auto writer = e.flattened_buffer.CopyOnWrite(); writer->dtype = DataType::Int(8); } From b5a1428c783f98f22d6b428bced072288d346114 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 8 Feb 2022 13:13:24 -0600 Subject: [PATCH 092/177] Bugfix, correctly represent void* in LLVM IR. --- src/target/llvm/codegen_llvm.cc | 7 +++++++ tests/python/unittest/test_runtime_module_load.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index bece4e4d5ffd..743c87823ba2 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -437,6 +437,13 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { if (auto* ptr = type.as()) { return DTypeToLLVMType(ptr->dtype); } else if (auto* ptr = type.as()) { + // LLVM IR doesn't allow void*, so we need to recognize this + // pattern explicitly. + if (auto* primtype = ptr->element_type.as()) { + if (primtype->dtype.is_void()) { + return t_void_p_; + } + } // TODO(tvm-team) consider put storage scope into the pointer type. return GetLLVMType(ptr->element_type)->getPointerTo(GetGlobalAddressSpace()); } else if (IsVoidType(type)) { diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py index f17a615ce2c1..9d067630879a 100644 --- a/tests/python/unittest/test_runtime_module_load.py +++ b/tests/python/unittest/test_runtime_module_load.py @@ -59,7 +59,7 @@ def save_object(names): 0, n - 1, tvm.tir.ForKind.SERIAL, - tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1), + tvm.tir.BufferStore(Ab, tvm.tir.BufferLoad(Ab, [i]) + 1, [i + 1]), ) mod = tvm.IRModule.from_expr( tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main") From 0b5b8402072a9860a96e19c1c100256ab265b975 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 8 Feb 2022 13:16:58 -0600 Subject: [PATCH 093/177] Update, replace tir.Load with tir.BufferLoad --- tests/python/unittest/test_runtime_module_based_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 162e10280d13..80dd4f7fd693 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -595,7 +595,7 @@ def make_func(symbol): 0, n - 1, tvm.tir.ForKind.SERIAL, - tvm.tir.Store(Ab.data, tvm.tir.Load("float32", Ab.data, i) + 1, i + 1), + tvm.tir.BufferStore(Ab, tvm.tir.BufferLoad(Ab, [i]) + 1, [i + 1]), ) return tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", symbol) From 249285a907ee753db57778c784a4182488832312 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 8 Feb 2022 13:32:23 -0600 Subject: [PATCH 094/177] Added TVMScript error check for matching buffer/index dimensionality Needed for tests/python/unittest/test_tvmscript_error_report.py::test_high_dim_store --- python/tvm/script/parser.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 8beff4eb27f4..afff26ef84c3 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -615,6 +615,12 @@ def transform_SubscriptAssign(self, node): rhs = self.transform(node.params[2]) rhs_span = tvm_span_from_synr(node.params[2].span) if isinstance(symbol, tvm.tir.Buffer): + if len(indexes) != len(symbol.shape): + self.report_error( + f"Buffer {symbol.name} is {len(symbol.shape)}-dimensional, " + f"cannot be indexed by {len(indexes)}-dimensional indices.", + node.params[1].span, + ) # BufferStore return tvm.tir.BufferStore( symbol, From 7b98cae90fe0d4c03c4bea30998841b7e27cb5af Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 8 Feb 2022 14:50:07 -0600 Subject: [PATCH 095/177] Bugfix, correct return type when lowering custom datatype. --- src/target/llvm/codegen_llvm.cc | 1 + src/tir/transforms/lower_custom_datatypes.cc | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 743c87823ba2..a521ae6b798b 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -793,6 +793,7 @@ CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); ICHECK(btype != nullptr); llvm::Type* llvm_type = DTypeToLLVMType(t); + ICHECK(llvm_type) << "Could not make LLVM type to represent TVM type " << t; llvm::PointerType* ttype = llvm_type->getPointerTo(btype->getAddressSpace()); if (btype != ttype) { buffer = builder_->CreatePointerCast(buffer, ttype); diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 16afa1133f68..fdb064076bca 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -115,7 +115,17 @@ class CustomDatatypesLowerer : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); - return VisitBufferAccess(std::move(node)); + auto modified = VisitBufferAccess(node); + + // Not needed for BufferStoreNode, so we can't just call + // LegalizeDtype() in VisitBufferAccess. + if (node.same_as(modified)) { + return std::move(node); + } else { + auto writer = modified.CopyOnWrite(); + writer->LegalizeDtype(); + return std::move(modified); + } } Stmt VisitStmt_(const BufferStoreNode* op) final { From 75819ba795afbff45b4fa9aa800d77ed82b83afe Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 8 Feb 2022 14:53:32 -0600 Subject: [PATCH 096/177] Bugfix, removed unused primfunc from test_tvmscript_complete.py --- tests/python/unittest/test_tvmscript_complete.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 882745704693..105b4a2d6a3f 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -314,12 +314,6 @@ def test_complete_alloc_buffer(): tvm.ir.assert_structural_equal(alloc_buffer_func, expect_alloc_buffer_func) -@T.prim_func -def load_var() -> None: - d = T.var("float32") - d[1] = d[1] - - if __name__ == "__main__": test_complete_matmul() test_complete_matmul_original() From a0969891daeac3a3ed03617d6c6a12d4b2fece3e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 9 Feb 2022 09:16:25 -0600 Subject: [PATCH 097/177] Updated test_meta_schedule_postproc_verify_gpu_code.py TIR Replaced Load/Store with BufferLoad/BufferStore. --- ..._meta_schedule_postproc_verify_gpu_code.py | 76 ++++++++++++------- 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index bebfec6122b3..bbf05629315f 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -59,8 +59,8 @@ def main(a: T.handle, b: T.handle) -> None: blockIdx_x = T.env_thread("blockIdx.x") blockIdx_y = T.env_thread("blockIdx.y") blockIdx_z = T.env_thread("blockIdx.z") - A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") - B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + A = T.match_buffer(a, [14*14*256*256], dtype="float32") + B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) B_local = T.allocate([64], "float32", "local") @@ -71,17 +71,22 @@ def main(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_y, 8) T.launch_thread(threadIdx_x, 8) for ff_c_init, nn_c_init in T.grid(8, 8): - T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + B_local[ff_c_init * 8 + nn_c_init] = T.float32(0) for rc_outer, ry, rx in T.grid(32, 3, 3): for ax3_inner_outer in T.serial(0, 2): - T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + Apad_shared[T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)] = T.if_then_else( + 1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, + A[T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4)], + T.broadcast(T.float32(0), 4), + dtype="float32x4", + ) for rc_inner in T.serial(0, 8): for ax3 in T.serial(0, 8): - T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3] for ff_c, nn_c in T.grid(8, 8): - T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c] for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): - T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner] # fmt: on @tvm.script.ir_module @@ -96,8 +101,8 @@ def main(a: T.handle, b: T.handle) -> None: blockIdx_x = T.env_thread("blockIdx.x") blockIdx_y = T.env_thread("blockIdx.y") blockIdx_z = T.env_thread("blockIdx.z") - A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") - B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + A = T.match_buffer(a, [14*14*256*256], dtype="float32") + B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) B_local = T.allocate([6400000], "float32", "local") @@ -108,17 +113,22 @@ def main(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_y, 8) T.launch_thread(threadIdx_x, 8) for ff_c_init, nn_c_init in T.grid(8, 8): - T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + B_local[ff_c_init * 8 + nn_c_init] = T.float32(0) for rc_outer, ry, rx in T.grid(32, 3, 3): for ax3_inner_outer in T.serial(0, 2): - T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + Apad_shared[T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)] = T.if_then_else( + 1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, + A[T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4)], + T.broadcast(T.float32(0), 4), + dtype="float32x4", + ) for rc_inner in T.serial(0, 8): for ax3 in T.serial(0, 8): - T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3] for ff_c, nn_c in T.grid(8, 8): - T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c] for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): - T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner]# fmt: on @tvm.script.ir_module @@ -133,8 +143,8 @@ def main(a: T.handle, b: T.handle) -> None: blockIdx_x = T.env_thread("blockIdx.x") blockIdx_y = T.env_thread("blockIdx.y") blockIdx_z = T.env_thread("blockIdx.z") - A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") - B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + A = T.match_buffer(a, [14*14*256*256], dtype="float32") + B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) B_local = T.allocate([64], "float32", "local") @@ -145,17 +155,22 @@ def main(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_y, 8) T.launch_thread(threadIdx_x, 8) for ff_c_init, nn_c_init in T.grid(8, 8): - T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + B_local[ff_c_init * 8 + nn_c_init] = T.float32(0) for rc_outer, ry, rx in T.grid(32, 3, 3): for ax3_inner_outer in T.serial(0, 2): - T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + Apad_shared[T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)] = T.if_then_else( + 1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, + A[T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4)], + T.broadcast(T.float32(0), 4), + dtype="float32x4", + ) for rc_inner in T.serial(0, 8): for ax3 in T.serial(0, 8): - T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3] for ff_c, nn_c in T.grid(8, 8): - T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c] for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): - T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner]# fmt: on @tvm.script.ir_module @@ -170,8 +185,8 @@ def main(a: T.handle, b: T.handle) -> None: blockIdx_x = T.env_thread("blockIdx.x") blockIdx_y = T.env_thread("blockIdx.y") blockIdx_z = T.env_thread("blockIdx.z") - A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") - B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + A = T.match_buffer(a, [14*14*256*256], dtype="float32") + B = T.match_buffer(b, [14*14*512*256], dtype="float32") # body T.launch_thread(blockIdx_z, 196) B_local = T.allocate([64], "float32", "local") @@ -182,17 +197,22 @@ def main(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_y, 8) T.launch_thread(threadIdx_x, 800000) for ff_c_init, nn_c_init in T.grid(8, 8): - T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + B_local[ff_c_init * 8 + nn_c_init] = T.float32(0) for rc_outer, ry, rx in T.grid(32, 3, 3): for ax3_inner_outer in T.serial(0, 2): - T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + Apad_shared[T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)] = T.if_then_else( + 1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, + A[T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4)], + T.broadcast(T.float32(0), 4), + dtype="float32x4", + ) for rc_inner in T.serial(0, 8): for ax3 in T.serial(0, 8): - T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3] for ff_c, nn_c in T.grid(8, 8): - T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c] for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): - T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner]# fmt: on # fmt: on From cb46d7eea3ddc2871ce526b7056a70f9cd0bd06e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 9 Feb 2022 09:17:03 -0600 Subject: [PATCH 098/177] Allowed ramp nodes with buffer use analysis. --- src/tir/analysis/block_access_region_detector.cc | 4 ++-- src/tir/ir/stmt.cc | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 03e02064f798..974f6ecd644f 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -147,7 +147,7 @@ void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(index, dom_map_)); + relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); } Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region); ExprVisitor::VisitExpr_(op); @@ -199,7 +199,7 @@ void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(index, dom_map_)); + relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); } Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region); StmtVisitor::VisitStmt_(op); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 484f8aa851f5..b6eec727e2c2 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -689,7 +689,12 @@ BufferRegion BufferRegion::FullRegion(Buffer buffer) { BufferRegion BufferRegion::FromPoint(Buffer buffer, Array indices) { Array region; for (const PrimExpr& index : indices) { - region.push_back(Range::FromMinExtent(index, 1)); + if (const RampNode* ramp_index = index.as()) { + region.push_back( + Range::FromMinExtent(ramp_index->base, ramp_index->stride * ramp_index->lanes)); + } else { + region.push_back(Range::FromMinExtent(index, 1)); + } } return BufferRegion(buffer, region); } From 5ae79fdffb0a092c64bbe7b086c5d30e2f622df8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 9 Feb 2022 10:05:53 -0600 Subject: [PATCH 099/177] Updated tests in test_meta_schedule_postproc_verify_gpu_code.py Needed dummy writes to prevent buffer resizing, in order to trigger the verification failure due to memory limits. --- .../test_meta_schedule_postproc_verify_gpu_code.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index bbf05629315f..b811ef31bf16 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -114,6 +114,10 @@ def main(a: T.handle, b: T.handle) -> None: T.launch_thread(threadIdx_x, 8) for ff_c_init, nn_c_init in T.grid(8, 8): B_local[ff_c_init * 8 + nn_c_init] = T.float32(0) + # Access of the last element of B_local prevents buffer + # compacting from reducing the amount of shared memory + # used. + B_local[6400000-1 + ff_c_init*8] = 0.0 for rc_outer, ry, rx in T.grid(32, 3, 3): for ax3_inner_outer in T.serial(0, 2): Apad_shared[T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)] = T.if_then_else( @@ -164,6 +168,10 @@ def main(a: T.handle, b: T.handle) -> None: T.broadcast(T.float32(0), 4), dtype="float32x4", ) + # Access of the last element of Apad_shared prevents + # buffer compacting from reducing the amount of shared + # memory used. + Apad_shared[512000-1] = 0.0 for rc_inner in T.serial(0, 8): for ax3 in T.serial(0, 8): Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3] @@ -230,6 +238,8 @@ def test_postproc_verify_gpu_1(): mod = Conv2dCuda1 ctx = _create_context(mod, target=_target()) sch = tir.Schedule(mod, debug_mask="all") + # Should fail due to too much local memory per block (large + # B_local allocation). assert not ctx.postprocs[0].apply(sch) @@ -237,6 +247,8 @@ def test_postproc_verify_gpu_2(): mod = Conv2dCuda2 ctx = _create_context(mod, target=_target()) sch = tir.Schedule(mod, debug_mask="all") + # Should fail due to too much local memory per block (large + # Apad_shared allocation). assert not ctx.postprocs[0].apply(sch) @@ -244,6 +256,8 @@ def test_postproc_verify_gpu_3(): mod = Conv2dCuda3 ctx = _create_context(mod, target=_target()) sch = tir.Schedule(mod, debug_mask="all") + # Should fail due to too many threads per block (large + # threadIdx.x extent). assert not ctx.postprocs[0].apply(sch) From c0c9329f61616ac9d602b5c7f991157538ec5510 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 9 Feb 2022 10:28:43 -0600 Subject: [PATCH 100/177] Updated TIR examples to be compatible with buffer dimension check. --- .../test_tir_analysis_calculate_workspace.py | 18 +- ...t_tir_analysis_detect_buffer_access_lca.py | 2 +- .../unittest/test_tir_lower_match_buffer.py | 2 +- .../test_tir_schedule_cache_read_write.py | 6 +- .../unittest/test_tir_schedule_reorder.py | 4 +- .../unittest/test_tir_schedule_split_fuse.py | 6 +- ..._tir_transform_convert_for_loops_serial.py | 12 +- tests/python/unittest/test_tir_usmp_algo.py | 56 ++--- ...st_tir_usmp_analysis_extract_bufferinfo.py | 212 +++++++++--------- ...orm_convert_pool_allocations_to_offsets.py | 32 +-- tests/python/unittest/test_tir_usmp_utils.py | 6 +- .../unittest/test_tvmscript_roundtrip.py | 12 +- 12 files changed, 184 insertions(+), 184 deletions(-) diff --git a/tests/python/unittest/test_tir_analysis_calculate_workspace.py b/tests/python/unittest/test_tir_analysis_calculate_workspace.py index 89e0791e457d..e866e996f174 100644 --- a/tests/python/unittest/test_tir_analysis_calculate_workspace.py +++ b/tests/python/unittest/test_tir_analysis_calculate_workspace.py @@ -29,7 +29,7 @@ def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.hand placeholder_147 = T.match_buffer(placeholder_144, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_148 = T.match_buffer(placeholder_145, [3, 3, 512, 1], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_149 = T.match_buffer(placeholder_146, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_49 = T.match_buffer(T_cast_48, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_cast_49 = T.match_buffer(T_cast_48, [1*14*14*512], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_22 = T.allocate([131072], "int16", "global") DepthwiseConv2d_9 = T.allocate([100352], "int32", "global") @@ -60,27 +60,27 @@ def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handl placeholder_165 = T.match_buffer(placeholder_162, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_166 = T.match_buffer(placeholder_163, [3, 3, 512, 1], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_167 = T.match_buffer(placeholder_164, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_77 = T.match_buffer(T_cast_76, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_cast_77 = T.match_buffer(T_cast_76, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_25 = T.allocate([1, 16, 16, 512], "int16", "global") + PaddedInput_25 = T.allocate([131072], "int16", "global") for i1_35, i2_46, i3_47 in T.grid(16, 16, 512): PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = T.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), placeholder_165[((((i1_35*7168) + (i2_46*512)) + i3_47) - 7680)], T.int16(0), dtype="int16") - T_add_11 = T.allocate([1, 14, 14, 512], "int32", "global") - with T.allocate([1, 14, 14, 512], "int32", "global") as DepthwiseConv2d_11: + T_add_11 = T.allocate([100352], "int32", "global") + with T.allocate([100352], "int32", "global") as DepthwiseConv2d_11: for i_11, j_11, c_11 in T.grid(14, 14, 512): DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = 0 for di_11, dj_11 in T.grid(3, 3): DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] + (PaddedInput_25[(((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)].astype("int32")*placeholder_166[(((di_11*1536) + (dj_11*512)) + c_11)].astype("int32"))) for ax1_44, ax2_45, ax3_47 in T.grid(14, 14, 512): T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = (DepthwiseConv2d_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] + placeholder_167[ax3_47]) - compute_22 = T.allocate([1, 14, 14, 512], "int32", "global") - with T.allocate([1, 14, 14, 512], "int32", "global") as T_cast_78: + compute_22 = T.allocate([100352], "int32", "global") + with T.allocate([100352], "int32", "global") as T_cast_78: for ax1_45, ax2_46, ax3_48 in T.grid(14, 14, 512): T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = T_add_11[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] for i1_36, i2_47, i3_48 in T.grid(14, 14, 512): compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = T.q_multiply_shift(T_cast_78[(((i1_36*7168) + (i2_47*512)) + i3_48)], 1948805937, 31, -5, dtype="int32") - T_cast_79 = T.allocate([1, 14, 14, 512], "uint8", "global") - with T.allocate([1, 14, 14, 512], "int32", "global") as compute_23: + T_cast_79 = T.allocate([100352], "uint8", "global") + with T.allocate([100352], "int32", "global") as compute_23: for i1_37, i2_48, i3_49 in T.grid(14, 14, 512): compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = T.max(T.max(compute_22[(((i1_37*7168) + (i2_48*512)) + i3_49)], 255), 0) for ax1_46, ax2_47, ax3_49 in T.grid(14, 14, 512): diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 9b688bb857f2..49121614ffa0 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -70,7 +70,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: @T.prim_func def lca_is_func_root(a: T.handle) -> None: A = T.match_buffer(a, [0, 0], "float32") - A[0] = 1.0 + A[0, 0] = 1.0 @T.prim_func diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 3a9af20a41b5..623cff6420e6 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -476,7 +476,7 @@ def fail_match_store(a: T.handle) -> None: T.reads([]) T.writes(A[i, j]) sub_A = T.match_buffer(A[i, j], ()) - sub_A[0] = 1 + sub_A[()] = 1 @T.prim_func diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index 7feb82a095fe..fd7066b04ebe 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -80,7 +80,7 @@ def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(D[vi, vj]) - D[vi * 128 + vj] = A[vi * 128 + vj] + D[vi, vj] = A[vi, vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) @@ -272,7 +272,7 @@ def cache_read_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) vi, vj = T.axis.remap("SS", [i, j]) T.reads(A_global[vi, vj]) T.writes(D[vi, vj]) - D[vi * 128 + vj] = A_global[vi * 128 + vj] + D[vi, vj] = A_global[vi, vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) @@ -481,7 +481,7 @@ def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(D_global[vi, vj]) - D_global[vi * 128 + vj] = A[vi * 128 + vj] + D_global[vi, vj] = A[vi, vj] for i, j in T.grid(8, 8): with T.block("opaque"): vi, vj = T.axis.remap("SS", [i, j]) diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index bfa469b86fbe..f62a316f8013 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -153,7 +153,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) - A[vi * 16 + vj] = 1 + A[vi, vj] = 1 for i, j in T.grid(16, 16): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) @@ -171,7 +171,7 @@ def opaque_access_reorder(a: T.handle, b: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) - A[vi * 16 + vj] = 1 + A[vi, vj] = 1 for j, i in T.grid(16, 16): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 576a8a99ef69..ea3a410fff96 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -273,7 +273,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) - A[vi * 16 + vj] = 1 + A[vi, vj] = 1 for i, j in T.grid(16, 16): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) @@ -292,7 +292,7 @@ def opaque_access_fused(a: T.handle, b: T.handle) -> None: vj = T.axis.S(16, T.floormod(i_j_fused, 16)) T.reads([]) T.writes([A[0:16, 0:16]]) - A[((vi * 16) + vj)] = 1 + A[vi, vj] = 1 for i_j_fused in T.serial(0, 256): with T.block("B"): vi = T.axis.S(16, T.floordiv(i_j_fused, 16)) @@ -312,7 +312,7 @@ def opaque_access_split(a: T.handle, b: T.handle) -> None: vj = T.axis.S(16, j0 * 4 + j1) T.reads([]) T.writes([A[0:16, 0:16]]) - A[((vi * 16) + vj)] = 1 + A[vi, vj] = 1 for i, j0, j1 in T.grid(16, 4, 4): with T.block("B"): vi = T.axis.S(16, i) diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py index d3b8fe40dbf1..38431705611b 100644 --- a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -26,18 +26,18 @@ def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) - placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_33 = T.match_buffer(placeholder_30, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = T.match_buffer(placeholder_31, [3072], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = T.match_buffer(placeholder_32, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = T.match_buffer(T_cast_8, [12544], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global") + PaddedInput_3 = T.allocate([150528], "int16", "global") for i0_i1_fused_3 in T.parallel(0, 28): for i2_3, i3_3 in T.grid(28, 192): PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3) ] = placeholder_33[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)] for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784): for ax3_2 in T.serial(0, 16): - Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global") + Conv2dOutput_3 = T.allocate([1], "int32", "global") Conv2dOutput_3[0] = 0 for rc_3 in T.serial(0, 192): Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34[((rc_3*16) + ax3_2)], "int32"))) diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py index 1cde5a1d6b5c..548fd96676a0 100644 --- a/tests/python/unittest/test_tir_usmp_algo.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -298,9 +298,9 @@ class MobilenetStructure: def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): @@ -310,10 +310,10 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): @@ -332,8 +332,8 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): @@ -418,9 +418,9 @@ class ResnetStructure: def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) - placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + T_cast_1 = T.match_buffer(T_cast, [360000], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -429,10 +429,10 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) - placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") - placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") - placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") - T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") + T_cast_5 = T.match_buffer(T_cast_4, [360000], dtype="int16") # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): @@ -450,10 +450,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) - placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") - placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") - placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") - T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [1440000], dtype="int32") # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): @@ -472,11 +472,11 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") - placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") - placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") - placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") - T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") + placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") + T_cast_7 = T.match_buffer(T_cast_6, [1440000], dtype="uint8") # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): @@ -512,10 +512,10 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) - placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") - placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") - placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") - T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") + T_cast_3 = T.match_buffer(T_cast_2, [360000], dtype="int16") # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index 78f97b43c00d..22b3d5826b3b 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -99,9 +99,9 @@ class LinearStructure: def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dTpe="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dTpe="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): @@ -111,10 +111,10 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): @@ -133,8 +133,8 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): @@ -207,10 +207,10 @@ class ParallelSerialMixedForLoops: def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) - placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_71 = T.match_buffer(placeholder_68, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): @@ -248,10 +248,10 @@ class AllSerialForLoops: def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) - placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_71 = T.match_buffer(placeholder_68, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): @@ -330,8 +330,8 @@ class InceptionStructure: def tvmgen_default_fused_nn_max_pool2d(placeholder: T.handle, tensor: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d", "tir.noalias": True}) - placeholder_1 = T.match_buffer(placeholder, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - tensor_1 = T.match_buffer(tensor, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_1 = T.match_buffer(placeholder, [602112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + tensor_1 = T.match_buffer(tensor, [249], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused in T.serial(0, 28): for ax2 in T.serial(0, 28): @@ -344,9 +344,9 @@ def tvmgen_default_fused_nn_max_pool2d(placeholder: T.handle, tensor: T.handle) def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): @@ -356,8 +356,8 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T def tvmgen_default_fused_cast(placeholder_6: T.handle, T_cast: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast", "tir.noalias": True}) - placeholder_7 = T.match_buffer(placeholder_6, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_1 = T.match_buffer(T_cast, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_7 = T.match_buffer(placeholder_6, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_1 = T.match_buffer(T_cast, [249], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_2 in T.serial(0, 28): for ax2_2, ax3_outer_1, ax3_inner_2 in T.grid(28, 12, 16): @@ -367,11 +367,11 @@ def tvmgen_default_fused_cast(placeholder_6: T.handle, T_cast: T.handle) -> None def tvmgen_default_fused_concatenate(placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, placeholder_11: T.handle, T_concat: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_concatenate", "tir.noalias": True}) - placeholder_12 = T.match_buffer(placeholder_8, [1, 28, 28, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_concat_1 = T.match_buffer(T_concat, [1, 28, 28, 256], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_13 = T.match_buffer(placeholder_9, [1, 28, 28, 128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_14 = T.match_buffer(placeholder_11, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_15 = T.match_buffer(placeholder_10, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_12 = T.match_buffer(placeholder_8, [50176], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_concat_1 = T.match_buffer(T_concat, [313], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_13 = T.match_buffer(placeholder_9, [100352], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_14 = T.match_buffer(placeholder_11, [25088], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_15 = T.match_buffer(placeholder_10, [25088], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_3 in T.serial(0, 28): for ax2_3, ax3 in T.grid(28, 256): @@ -381,10 +381,10 @@ def tvmgen_default_fused_concatenate(placeholder_8: T.handle, placeholder_9: T.h def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_cast_2: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) - placeholder_19 = T.match_buffer(placeholder_16, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_3 = T.match_buffer(T_cast_2, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_19 = T.match_buffer(placeholder_16, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_20 = T.match_buffer(placeholder_17, [4096], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_21 = T.match_buffer(placeholder_18, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_3 = T.match_buffer(T_cast_2, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body PaddedInput = T.allocate([200704], "int16", "global") for i0_i1_fused in T.serial(0, 56): @@ -403,10 +403,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, T_cast_4: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) - placeholder_25 = T.match_buffer(placeholder_22, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_26 = T.match_buffer(placeholder_23, [1, 1, 192, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_27 = T.match_buffer(placeholder_24, [1, 1, 1, 96], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_5 = T.match_buffer(T_cast_4, [1, 28, 28, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_25 = T.match_buffer(placeholder_22, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_26 = T.match_buffer(placeholder_23, [18432], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_27 = T.match_buffer(placeholder_24, [96], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_5 = T.match_buffer(T_cast_4, [153], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_1 = T.allocate([150528], "int16", "global") for i0_i1_fused_1 in T.serial(0, 28): @@ -424,8 +424,8 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): @@ -442,10 +442,10 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", "tir.noalias": True}) - placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_33 = T.match_buffer(placeholder_30, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = T.match_buffer(placeholder_31, [12288], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = T.match_buffer(placeholder_32, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = T.match_buffer(T_cast_8, [121], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_2 = T.allocate([150528], "int16", "global") for i0_i1_fused_2 in T.serial(0, 28): @@ -464,8 +464,8 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2(placehol def tvmgen_default_fused_nn_max_pool2d_cast_1(placeholder_36: T.handle, T_cast_10: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast_1", "tir.noalias": True}) - placeholder_37 = T.match_buffer(placeholder_36, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_11 = T.match_buffer(T_cast_10, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_37 = T.match_buffer(placeholder_36, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_11 = T.match_buffer(T_cast_10, [249], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_3 = T.allocate([150528], "uint8", "global") for ax0_ax1_fused_6 in T.serial(0, 28): @@ -482,10 +482,10 @@ def tvmgen_default_fused_nn_max_pool2d_cast_1(placeholder_36: T.handle, T_cast_1 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2(placeholder_38: T.handle, placeholder_39: T.handle, placeholder_40: T.handle, T_cast_12: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", "tir.noalias": True}) - placeholder_41 = T.match_buffer(placeholder_38, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_42 = T.match_buffer(placeholder_39, [1, 1, 192, 32], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_43 = T.match_buffer(placeholder_40, [1, 1, 1, 32], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_13 = T.match_buffer(T_cast_12, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_41 = T.match_buffer(placeholder_38, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_42 = T.match_buffer(placeholder_39, [6144], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_43 = T.match_buffer(placeholder_40, [32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_13 = T.match_buffer(T_cast_12, [89], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_3 = T.allocate([150528], "int16", "global") for i0_i1_fused_3 in T.serial(0, 28): @@ -503,10 +503,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_44: T.handle, placeholder_45: T.handle, placeholder_46: T.handle, T_cast_14: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) - placeholder_47 = T.match_buffer(placeholder_44, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_48 = T.match_buffer(placeholder_45, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_49 = T.match_buffer(placeholder_46, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_15 = T.match_buffer(T_cast_14, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_47 = T.match_buffer(placeholder_44, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_48 = T.match_buffer(placeholder_45, [3072], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_49 = T.match_buffer(placeholder_46, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_15 = T.match_buffer(T_cast_14, [73], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_4 = T.allocate([150528], "int16", "global") for i0_i1_fused_4 in T.serial(0, 28): @@ -524,10 +524,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(pla def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1(placeholder_50: T.handle, placeholder_51: T.handle, placeholder_52: T.handle, T_cast_16: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", "tir.noalias": True}) - placeholder_53 = T.match_buffer(placeholder_50, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_54 = T.match_buffer(placeholder_51, [3, 3, 16, 32], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_55 = T.match_buffer(placeholder_52, [1, 1, 1, 32], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_17 = T.match_buffer(T_cast_16, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_53 = T.match_buffer(placeholder_50, [12544], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_54 = T.match_buffer(placeholder_51, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_55 = T.match_buffer(placeholder_52, [32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_17 = T.match_buffer(T_cast_16, [89], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_5 = T.allocate([14400], "int16", "global") for i0_i1_fused_5 in T.serial(0, 30): @@ -545,10 +545,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_(placeholder_56: T.handle, placeholder_57: T.handle, placeholder_58: T.handle, T_cast_18: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", "tir.noalias": True}) - placeholder_59 = T.match_buffer(placeholder_56, [1, 28, 28, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_60 = T.match_buffer(placeholder_57, [3, 3, 96, 128], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_61 = T.match_buffer(placeholder_58, [1, 1, 1, 128], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_19 = T.match_buffer(T_cast_18, [1, 28, 28, 128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_59 = T.match_buffer(placeholder_56, [75264], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_60 = T.match_buffer(placeholder_57, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_61 = T.match_buffer(placeholder_58, [128], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_19 = T.match_buffer(T_cast_18, [185], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_6 = T.allocate([86400], "int16", "global") for i0_i1_fused_6 in T.serial(0, 30): @@ -568,10 +568,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "T.noalias": True}) - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): @@ -590,10 +590,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) - placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_71 = T.match_buffer(placeholder_68, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_8 = T.allocate([215296], "int16", "global") for i0_i1_fused_8 in T.serial(0, 58): @@ -1107,8 +1107,8 @@ class MultipleCallsToSamePrimFuncModule: def tvmgen_default_fused_layout_transform_1(placeholder: T.handle, T_layout_trans: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_layout_transform_1", "tir.noalias": True}) - placeholder_1 = T.match_buffer(placeholder, [1, 3, 24, 12], dtype="float32") - T_layout_trans_1 = T.match_buffer(T_layout_trans, [1, 1, 24, 12, 3], dtype="float32") + placeholder_1 = T.match_buffer(placeholder, [864], dtype="float32") + T_layout_trans_1 = T.match_buffer(T_layout_trans, [41], dtype="float32") # body for ax0_ax1_fused_ax2_fused, ax3, ax4_inner in T.grid(24, 12, 3): T_layout_trans_1[ax0_ax1_fused_ax2_fused * 36 + ax3 * 3 + ax4_inner] = placeholder_1[ax4_inner * 288 + ax0_ax1_fused_ax2_fused * 12 + ax3] @@ -1117,15 +1117,15 @@ def tvmgen_default_fused_layout_transform_1(placeholder: T.handle, T_layout_tran def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeholder_3: T.handle, conv2d_NCHWc: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_contrib_conv2d_NCHWc", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 1, 24, 12, 3], dtype="float32") - placeholder_5 = T.match_buffer(placeholder_3, [1, 1, 3, 3, 3, 3], dtype="float32") - conv2d_NCHWc_1 = T.match_buffer(conv2d_NCHWc, [1, 1, 24, 12, 3], dtype="float32") + placeholder_4 = T.match_buffer(placeholder_2, [864], dtype="float32") + placeholder_5 = T.match_buffer(placeholder_3, [81], dtype="float32") + conv2d_NCHWc_1 = T.match_buffer(conv2d_NCHWc, [41], dtype="float32") # body - data_pad = T.allocate([1, 1, 26, 14, 3], "float32", "global") + data_pad = T.allocate([1092], "float32", "global") for i0_i1_fused_i2_fused, i3, i4 in T.grid(26, 14, 3): data_pad[i0_i1_fused_i2_fused * 42 + i3 * 3 + i4] = T.if_then_else(1 <= i0_i1_fused_i2_fused and i0_i1_fused_i2_fused < 25 and 1 <= i3 and i3 < 13, placeholder_4[i0_i1_fused_i2_fused * 36 + i3 * 3 + i4 - 39], T.float32(0), dtype="float32") for n_oc_chunk_fused_oh_fused in T.serial(0, 24): - conv2d_NCHWc_global = T.allocate([1, 1, 1, 12, 3], "float32", "global") + conv2d_NCHWc_global = T.allocate([36], "float32", "global") for oc_block_c_init in T.serial(0, 3): conv2d_NCHWc_global[oc_block_c_init] = T.float32(0) for oc_block_c_init in T.serial(0, 3): @@ -1182,23 +1182,23 @@ def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeh def tvmgen_default_fused_nn_softmax_add_add_multiply_add(placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, T_add: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_softmax_add_add_multiply_add", "tir.noalias": True}) - placeholder_11 = T.match_buffer(placeholder_6, [1, 3, 24, 12], dtype="float32") - placeholder_12 = T.match_buffer(placeholder_7, [1, 3, 24, 12], dtype="float32") - placeholder_13 = T.match_buffer(placeholder_8, [3, 1, 1], dtype="float32") - placeholder_14 = T.match_buffer(placeholder_9, [3, 1, 1], dtype="float32") - placeholder_15 = T.match_buffer(placeholder_10, [3, 1, 1], dtype="float32") - T_add_1 = T.match_buffer(T_add, [1, 3, 24, 12], dtype="float32") + placeholder_11 = T.match_buffer(placeholder_6, [864], dtype="float32") + placeholder_12 = T.match_buffer(placeholder_7, [864], dtype="float32") + placeholder_13 = T.match_buffer(placeholder_8, [3], dtype="float32") + placeholder_14 = T.match_buffer(placeholder_9, [3], dtype="float32") + placeholder_15 = T.match_buffer(placeholder_10, [3], dtype="float32") + T_add_1 = T.match_buffer(T_add, [864], dtype="float32") # body for ax0_ax1_fused_ax2_fused in T.serial(0, 72): - T_softmax_norm = T.allocate([1, 1, 1, 12], "float32", "global") - with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: + T_softmax_norm = T.allocate([12], "float32", "global") + with T.allocate([1], "float32", "global") as T_softmax_maxelem: T_softmax_maxelem[0] = T.float32(-3.4028234663852886e+38) for k in T.serial(0, 12): T_softmax_maxelem[0] = T.max(T_softmax_maxelem[0], placeholder_11[ax0_ax1_fused_ax2_fused * 12 + k]) - T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") + T_softmax_exp = T.allocate([12], "float32", "global") for i3 in T.serial(0, 12): T_softmax_exp[i3] = T.exp(placeholder_11[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32") - T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") + T_softmax_expsum = T.allocate([1], "float32", "global") T_softmax_expsum[0] = T.float32(0) for k in T.serial(0, 12): T_softmax_expsum[0] = T_softmax_expsum[0] + T_softmax_exp[k] @@ -1211,13 +1211,13 @@ def tvmgen_default_fused_nn_softmax_add_add_multiply_add(placeholder_6: T.handle def tvmgen_default_fused_nn_contrib_dense_pack_nn_relu(placeholder_16: T.handle, placeholder_17: T.handle, T_relu: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", "tir.noalias": True}) - placeholder_18 = T.match_buffer(placeholder_16, [72, 12], dtype="float32") - placeholder_19 = T.match_buffer(placeholder_17, [2, 12, 6], dtype="float32") - T_relu_1 = T.match_buffer(T_relu, [72, 12], dtype="float32") + placeholder_18 = T.match_buffer(placeholder_16, [864], dtype="float32") + placeholder_19 = T.match_buffer(placeholder_17, [144], dtype="float32") + T_relu_1 = T.match_buffer(T_relu, [864], dtype="float32") # body for ax1_outer_ax0_outer_fused in T.serial(0, 18): - compute = T.allocate([8, 6], "float32", "global") - with T.allocate([8, 6], "float32", "global") as compute_global: + compute = T.allocate([48], "float32", "global") + with T.allocate([48], "float32", "global") as compute_global: for x_c_init in T.serial(0, 6): compute_global[x_c_init] = T.float32(0) for x_c_init in T.serial(0, 6): @@ -1274,8 +1274,8 @@ def tvmgen_default_fused_nn_contrib_dense_pack_nn_relu(placeholder_16: T.handle, def tvmgen_default_fused_reshape_1(placeholder_20: T.handle, T_reshape: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_reshape_1", "tir.noalias": True}) - placeholder_21 = T.match_buffer(placeholder_20, [1, 3, 24, 12], dtype="float32") - T_reshape_1 = T.match_buffer(T_reshape, [72, 12], dtype="float32") + placeholder_21 = T.match_buffer(placeholder_20, [864], dtype="float32") + T_reshape_1 = T.match_buffer(T_reshape, [864], dtype="float32") # body for ax0, ax1_inner in T.grid(72, 12): T_reshape_1[ax0 * 12 + ax1_inner] = placeholder_21[ax0 * 12 + ax1_inner] @@ -1284,8 +1284,8 @@ def tvmgen_default_fused_reshape_1(placeholder_20: T.handle, T_reshape: T.handle def tvmgen_default_fused_layout_transform(placeholder_22: T.handle, T_layout_trans_2: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_layout_transform", "tir.noalias": True}) - placeholder_23 = T.match_buffer(placeholder_22, [1, 1, 24, 12, 3], dtype="float32") - T_layout_trans_3 = T.match_buffer(T_layout_trans_2, [1, 3, 24, 12], dtype="float32") + placeholder_23 = T.match_buffer(placeholder_22, [864], dtype="float32") + T_layout_trans_3 = T.match_buffer(T_layout_trans_2, [864], dtype="float32") # body for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): T_layout_trans_3[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] = placeholder_23[ax2 * 36 + ax3_inner * 3 + ax0_ax1_fused] @@ -1294,8 +1294,8 @@ def tvmgen_default_fused_layout_transform(placeholder_22: T.handle, T_layout_tra def tvmgen_default_fused_reshape(placeholder_24: T.handle, T_reshape_2: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_reshape", "tir.noalias": True}) - placeholder_25 = T.match_buffer(placeholder_24, [72, 12], dtype="float32") - T_reshape_3 = T.match_buffer(T_reshape_2, [1, 3, 24, 12], dtype="float32") + placeholder_25 = T.match_buffer(placeholder_24, [864], dtype="float32") + T_reshape_3 = T.match_buffer(T_reshape_2, [864], dtype="float32") # body for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): T_reshape_3[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] = placeholder_25[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] @@ -1304,20 +1304,20 @@ def tvmgen_default_fused_reshape(placeholder_24: T.handle, T_reshape_2: T.handle def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27: T.handle, T_add_2: T.handle) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_softmax_add", "tir.noalias": True}) - placeholder_28 = T.match_buffer(placeholder_26, [1, 3, 24, 12], dtype="float32") - placeholder_29 = T.match_buffer(placeholder_27, [1, 3, 24, 12], dtype="float32") - T_add_3 = T.match_buffer(T_add_2, [1, 3, 24, 12], dtype="float32") + placeholder_28 = T.match_buffer(placeholder_26, [864], dtype="float32") + placeholder_29 = T.match_buffer(placeholder_27, [864], dtype="float32") + T_add_3 = T.match_buffer(T_add_2, [864], dtype="float32") # body for ax0_ax1_fused_ax2_fused in T.serial(0, 72): - T_softmax_norm = T.allocate([1, 1, 1, 12], "float32", "global") - with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: + T_softmax_norm = T.allocate([12], "float32", "global") + with T.allocate([1], "float32", "global") as T_softmax_maxelem: T_softmax_maxelem[0] = T.float32(-3.4028234663852886e+38) for k in T.serial(0, 12): T_softmax_maxelem[0] = T.max(T_softmax_maxelem[0], placeholder_28[ax0_ax1_fused_ax2_fused * 12 + k]) - T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") + T_softmax_exp = T.allocate([12], "float32", "global") for i3 in T.serial(0, 12): T_softmax_exp[i3] = T.exp(placeholder_28[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32") - T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") + T_softmax_expsum = T.allocate([1], "float32", "global") T_softmax_expsum[0] = T.float32(0) for k in T.serial(0, 12): T_softmax_expsum[0] = T_softmax_expsum[0] + T_softmax_exp[k] @@ -1330,8 +1330,8 @@ def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27 def run_model(data: T.handle, output: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) - data_buffer = T.match_buffer(data, [1, 3, 24, 12], dtype="float32", align=16) - output_buffer = T.match_buffer(output, [1, 3, 24, 12], dtype="float32", align=16) + data_buffer = T.match_buffer(data, [864], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [864], dtype="float32", align=16) # body sid_11 = T.allocate([3456], "int8", "global.workspace") sid_5 = T.allocate([3456], "int8", "global.workspace") diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 404832f814a4..7d0aad49f7f7 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -75,7 +75,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): @@ -88,7 +88,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): @@ -108,7 +108,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): @@ -155,7 +155,7 @@ def run_model(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_ @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8") - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16") + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body @@ -173,7 +173,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr[T.uint8], slow_memory_3_var: T.Ptr[T.uint8]) -> None: placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16") - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16") + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body @@ -185,7 +185,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16") placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32") - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8") + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8") fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body @@ -252,7 +252,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -264,7 +264,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") - T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): @@ -285,7 +285,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") - T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [407], dtype="int32") # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): @@ -308,7 +308,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") - T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): @@ -347,7 +347,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") - T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): @@ -370,7 +370,7 @@ class ResnetStructurePlanned: def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr[T.uint8]) -> None: placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): @@ -382,7 +382,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") - T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body PaddedInput_3_let = T.buffer_decl([360000], 'int16') @@ -405,7 +405,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") - T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [407], dtype="int32") global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body PaddedInput_2_let = T.buffer_decl([360000], "int16") @@ -428,7 +428,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") - T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body PaddedInput_let = T.buffer_decl([360000], "int16") @@ -450,7 +450,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") - T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body PaddedInput_1_let = T.buffer_decl([379456], "int16") diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py index fa70cec9de4f..e1541021981a 100644 --- a/tests/python/unittest/test_tir_usmp_utils.py +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -33,7 +33,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dTpe="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): @@ -46,7 +46,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): @@ -66,7 +66,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index a6f22adc0858..b3243b9a0599 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -36,7 +36,7 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: packedB = T.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) A_1 = T.match_buffer(A, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) B_1 = T.match_buffer(B, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) - C_1 = T.match_buffer(C, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) # body T.realize(packedB[0:32, 0:1024, 0:32], "") for x in T.parallel(0, 32): @@ -90,7 +90,7 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = T.match_buffer(C, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) # body packedB = T.allocate([32768], "float32", "global") for x in T.parallel(0, 32): @@ -224,7 +224,7 @@ def mmult( C_data: T.Ptr[T.int32] = T.tvm_struct_get(arg2, 0, 1, dtype="handle") T.attr(C_data, "storage_alignment", 128) - C: T.Buffer = T.buffer_decl([1024, 1024], dtype="int32", data=C_data) + C: T.Buffer = T.buffer_decl([1024 * 1024], dtype="int32", data=C_data) buf2_shape_data: T.Ptr[T.int32] = T.tvm_struct_get(arg2, 0, 2, dtype="handle") buf2_shape: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf2_shape_data) buf2_strides_data: T.Ptr[T.int32] = T.tvm_struct_get(arg2, 0, 3, dtype="handle") @@ -2472,7 +2472,7 @@ def vthread_func(): @T.prim_func def vthread_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") + C = T.match_buffer(c, [256], "float32") i0 = T.env_thread("blockIdx.x") i1 = T.env_thread("threadIdx.x") @@ -2776,7 +2776,7 @@ def rank0_block(): def rank0_block(a: T.handle) -> None: A = T.match_buffer(a, (), "float32") B = T.alloc_buffer((), "float32") - B[0] = A[0] + B[()] = A[()] with T.block("update") as []: T.reads([A[()]]) @@ -2897,7 +2897,7 @@ def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.han # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"}) for ax0_ax1_fused_4 in T.serial(0, 56): From 3b20b42f35ddd9f3af4e68968e73ed93c3f71076 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 9 Feb 2022 12:11:57 -0600 Subject: [PATCH 101/177] Corrected section header in docstring. --- python/tvm/tir/buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 53c0916e599f..e36a99339e48 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -156,8 +156,8 @@ def get_flattened_buffer(self): def offset_of(self, indices): """Determine the offset of the provided indices in the flattened buffer. - Params - ------- + Parameters + ---------- indices : Union[PrimExpr, List[PrimExpr]] The indices of the element in the original buffer. From 63941f59098e93c28aaae9e96a4d79f76a0c0e22 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 9 Feb 2022 15:40:30 -0600 Subject: [PATCH 102/177] Corrected indices size check in CogeGenC. --- src/target/source/codegen_c.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 9ba321808882..435a7c15b972 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -587,7 +587,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else if (op->op.same_as(builtin::address_of())) { const BufferLoadNode* load = op->args[0].as(); ICHECK(op->args.size() == 1 && load); - ICHECK_EQ(load->indices.size(), 0) << "CodeGenC only supports flat memory allocations."; + ICHECK_EQ(load->indices.size(), 1) << "CodeGenC only supports flat memory allocations."; os << "(("; this->PrintType(load->dtype.element_of(), os); os << " *)" << this->GetVarID(load->buffer->data.get()) << " + " From 216fa9a687cbc74f15de6d6c7884ff9dd3c414b5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 10 Feb 2022 09:16:05 -0600 Subject: [PATCH 103/177] Fixed breakage in LowerThreadAllreduce. Since the AllocateNode is rewritten, any buffers that refer to those variables must also be rewritten. --- src/tir/transforms/lower_thread_allreduce.cc | 69 +++++++++++++++++--- 1 file changed, 59 insertions(+), 10 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 5d17aae1f2b8..3f344c751cdf 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -98,13 +98,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (it != alloc_remap_.end()) { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { - stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); new_storage_scopes_[repl->buffer_var.get()] = "local"; } else { - stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); new_storage_scopes_[repl->buffer_var.get()] = "shared"; } - return stmt; + return Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); } else { return stmt; } @@ -121,15 +119,39 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } PrimExpr VisitExpr_(const BufferLoadNode* op) final { - auto it = load_remap_.find(op->buffer.get()); - if (it != load_remap_.end()) { - for (const auto& index : op->indices) { - ICHECK(is_zero(index)); + { + auto it = load_remap_.find(op->buffer.get()); + if (it != load_remap_.end()) { + for (const auto& index : op->indices) { + ICHECK(is_zero(index)); + } + return it->second; + } + } + + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + op = load.get(); + + { + auto it = buf_remap_.find(op->buffer.get()); + if (it != buf_remap_.end()) { + return BufferLoad(it->second, op->indices, op->span); } - return it->second; - } else { - return StmtExprMutator::VisitExpr_(op); } + + { + auto it = var_remap_.find(op->buffer->data.get()); + if (it != var_remap_.end()) { + Buffer remapped_buffer(it->second, op->buffer->dtype, op->buffer->shape, + op->buffer->strides, op->buffer->elem_offset, op->buffer->name, + op->buffer->data_alignment, op->buffer->offset_factor, + op->buffer->buffer_type, op->buffer->axis_separators, + op->buffer->span); + buf_remap_[op->buffer.get()] = remapped_buffer; + return BufferLoad(remapped_buffer, op->indices, op->span); + } + } + return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const BufferStoreNode* op) final { @@ -143,6 +165,27 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto writer = store.CopyOnWrite(); writer->buffer = it->second; + return std::move(store); + } + + { + auto it = buf_remap_.find(store->buffer.get()); + if (it != buf_remap_.end()) { + return BufferStore(it->second, store->value, store->indices, store->span); + } + } + + { + auto it = var_remap_.find(store->buffer->data.get()); + if (it != var_remap_.end()) { + Buffer remapped_buffer(it->second, store->buffer->dtype, store->buffer->shape, + store->buffer->strides, store->buffer->elem_offset, + store->buffer->name, store->buffer->data_alignment, + store->buffer->offset_factor, store->buffer->buffer_type, + store->buffer->axis_separators, store->buffer->span); + buf_remap_[store->buffer.get()] = remapped_buffer; + return BufferStore(remapped_buffer, store->value, store->indices, store->span); + } } return std::move(store); @@ -365,6 +408,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Array extents{PrimExpr(1)}; auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0)); alloc_remap_[buffers[i]->data.get()] = node; + var_remap_[buffers[i]->data.get()] = buf->data; warp_allocs_.insert(node.get()); } } else { @@ -407,6 +451,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { alloc_remap_[buffers[idx]->data.get()] = Allocate(shared_bufs[idx]->data, types[idx], {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0)); + var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data; store_remap_[buffers[idx].get()] = shared_bufs[idx]; } } @@ -629,6 +674,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::unordered_map store_remap_; // Allocate remap std::unordered_map alloc_remap_; + // BufferVar remap + std::unordered_map var_remap_; + // Buffer remap + std::unordered_map buf_remap_; // Allocate from warp reductions std::unordered_set warp_allocs_; // Internal analyzer From 0c4194bf4c925a91d91998a6b3da1775182ee268 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 10 Feb 2022 09:43:17 -0600 Subject: [PATCH 104/177] [UnitTests] Replaced Store/Load in CUDA codegen tests. --- tests/python/unittest/test_target_codegen_cuda.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 305f82558edc..1227ed8e88d7 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -280,17 +280,17 @@ def vectorizer(op): all_ones = tvm.tir.const(1, "int32x4") store = op.body value = store.value - new_a = tvm.tir.Load("int32x4", value.a.buffer_var, idx, all_ones) + new_a = tvm.tir.BufferLoad(value.a.buffer, [idx]) bs, ids = [], [] for i in range(4): bs.append( - tvm.tir.Load( - "int32", value.b.buffer_var, thrx.var * four + tvm.tir.const(i, "int32") + tvm.tir.BufferLoad( + value.b.buffer, [thrx.var * four + tvm.tir.const(i, "int32")] ) ) ids.append(tvm.tir.const(3 - i, "int32")) new_b = tvm.tir.Shuffle(bs, ids) - return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) + return tvm.tir.BufferStore(store.buffer, new_a + new_b, [idx]) return None def _transform(f, *_): From 5329a05cdc7864f0aecd4d90de6e783c82f5417a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 10 Feb 2022 15:46:47 -0600 Subject: [PATCH 105/177] Resolved breakage in C-based codegen for vectorized store/load. Needed to update to new convention of using the buffer's element type as the stride. --- src/target/source/codegen_c.cc | 129 +++++++----------- src/target/source/codegen_c.h | 6 +- src/target/source/codegen_opencl.cc | 13 +- src/target/source/codegen_opencl.h | 6 +- .../unittest/test_target_codegen_cuda.py | 26 ++-- 5 files changed, 76 insertions(+), 104 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 435a7c15b972..ce25542b7c17 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -158,78 +158,52 @@ void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, } // Print a reference expression to a buffer. -std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) { +std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) { + const VarNode* buffer_var = buffer->data.get(); std::ostringstream os; - std::string vid = GetVarID(buffer); + std::string vid = GetVarID(buffer_var); std::string scope; - if (alloc_storage_scope_.count(buffer)) { - scope = alloc_storage_scope_.at(buffer); + if (alloc_storage_scope_.count(buffer_var)) { + scope = alloc_storage_scope_.at(buffer_var); } - bool is_vol = IsVolatile(buffer); - if (t.lanes() == 1) { - if (!HandleTypeMatch(buffer, t) || is_vol) { - os << "(("; - if (is_vol) { - os << "volatile "; - } - // Scope may not be part of type. - if (!scope.empty() && IsScopePartOfType()) { - PrintStorageScope(scope, os); - } - PrintType(t, os); - os << "*)" << vid << ')'; - } else { - os << vid; - } - os << "[("; - PrintExpr(index, os); - os << ")"; - if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { - os << " / " << (32 / t.bits()); - } - os << ']'; - } else { - // Buffer declared as vector type. - // optimize for case where it is in register, - if (HandleTypeMatch(buffer, t) && !is_vol) { - // optimize for constant access - if (auto* ptr = index.as()) { - int64_t offset = ptr->value; - ICHECK_EQ(offset % t.lanes(), 0) << "Find unaligned vector load to a vector type"; - os << vid << '[' << (offset / t.lanes()) << ']'; - return os.str(); - } - } - os << "(("; + bool is_vol = IsVolatile(buffer_var); + + auto ptr_cast = [this, is_vol, scope](DataType pointed_to) { + std::ostringstream ptr_os; + ptr_os << "("; if (is_vol) { - os << "volatile "; + ptr_os << "volatile "; } if (!scope.empty() && IsScopePartOfType()) { - PrintStorageScope(scope, os); - } - PrintType(t, os); - os << "*)("; - if (!HandleTypeMatch(buffer, t.element_of())) { - os << '('; - if (!scope.empty() && IsScopePartOfType()) { - PrintStorageScope(scope, os); - } - PrintType(t.element_of(), os); - os << "*)"; - } - if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { - os << vid << ") + ("; - PrintExpr(index, os); - os << ")"; - os << " / " << t.lanes(); - os << ")[0]"; - } else { - os << vid << " + ("; - PrintExpr(index, os); - os << ")"; - os << "))[0]"; + PrintStorageScope(scope, ptr_os); } + PrintType(pointed_to, ptr_os); + ptr_os << "*)"; + return ptr_os.str(); + }; + + DataType buffer_element_dtype = buffer->dtype; + + std::string buffer_str = vid; + if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) { + std::stringstream temp; + temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")"; + buffer_str = temp.str(); } + + std::string index_str = PrintExpr(index); + if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { + std::stringstream temp; + temp << "(" << index_str << ") / " << (32 / t.bits()); + index_str = temp.str(); + } + + if (t == buffer_element_dtype) { + os << buffer_str << "[" << index_str << "]"; + } else { + os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")"; + } + return os.str(); } @@ -333,11 +307,11 @@ void CodeGenC::PrintVecElemStore(const std::string& vec, DataType t, int i, stream << vec << ".s" << std::hex << i << " = " << value << ";\n" << std::dec; } -std::string CodeGenC::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) { +std::string CodeGenC::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) { return GetBufferRef(t, buffer, base); } -void CodeGenC::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, +void CodeGenC::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, const std::string& value) { std::string ref = GetBufferRef(t, buffer, base); this->PrintIndent(); @@ -656,13 +630,15 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; + DataType value_dtype = op->dtype; PrimExpr index = op->indices[0]; Var buffer_var = op->buffer->data; + DataType element_dtype = op->buffer->dtype; int lanes = op->dtype.lanes(); // delcare type. - if (op->dtype.lanes() == 1) { - std::string ref = GetBufferRef(op->dtype, buffer_var.get(), index); + if (value_dtype.lanes() == element_dtype.lanes()) { + std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index); HandleVolatileLoads(ref, op, os); } else { bool can_vector_load = false; @@ -678,7 +654,7 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI } if (can_vector_load) { - std::string ref = GetVecLoad(op->dtype, buffer_var.get(), base.Eval()); + std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval()); HandleVolatileLoads(ref, op, os); } else { std::ostringstream svalue_expr; @@ -717,21 +693,22 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { void CodeGenC::VisitStmt_(const BufferStoreNode* op) { ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; - DataType t = op->value.dtype(); + DataType value_dtype = op->value.dtype(); + DataType element_dtype = op->buffer->dtype; PrimExpr index_expr = op->indices[0]; Var buffer_var = op->buffer->data; - if (t.lanes() == 1) { + if (value_dtype.lanes() == element_dtype.lanes()) { std::string value = this->PrintExpr(op->value); - std::string ref = this->GetBufferRef(t, buffer_var.get(), index_expr); + std::string ref = this->GetBufferRef(value_dtype, op->buffer.get(), index_expr); this->PrintIndent(); stream << ref << " = " << value << ";\n"; } else { arith::PVar base; - if (arith::ramp(base, 1, t.lanes()).Match(index_expr)) { + if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) { std::string value = this->PrintExpr(op->value); - this->PrintVecStore(buffer_var.get(), t, base.Eval(), value); + this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value); } else { // The assignment below introduces side-effect, and the resulting value cannot // be reused across multiple expression, thus a new scope is needed @@ -741,9 +718,9 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) { std::string index = SSAGetID(PrintExpr(index_expr), index_expr.dtype()); std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype()); std::string vid = GetVarID(buffer_var.get()); - for (int i = 0; i < t.lanes(); ++i) { + for (int i = 0; i < value_dtype.lanes(); ++i) { this->PrintIndent(); - DataType elem_type = t.element_of(); + DataType elem_type = value_dtype; if (!HandleTypeMatch(buffer_var.get(), elem_type)) { stream << "(("; if (buffer_var.get()->dtype.is_handle()) { diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 3536b74b5636..1532b6f358cb 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -177,9 +177,9 @@ class CodeGenC : public ExprFunctor, virtual void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs, std::ostream& os); // NOLINT(*) // print vector load - virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base); + virtual std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base); // print vector store - virtual void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, + virtual void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, const std::string& value); // NOLINT(*) // print load of single element virtual void PrintVecElemLoad(const std::string& vec, DataType t, int i, @@ -198,7 +198,7 @@ class CodeGenC : public ExprFunctor, // Print reference to struct location std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); // Print reference to a buffer as type t in index. - virtual std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index); + virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index); /*! * \brief Handle volatile loads. diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 8d0179c183f2..16055de9e634 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -260,21 +260,22 @@ void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) { // NOLINT(* } } -void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, +void CodeGenOpenCL::PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base, std::ostream& os) { // NOLINT(*) - if (!HandleTypeMatch(buffer, t.element_of())) { + const VarNode* buffer_var = buffer->data.get(); + if (!HandleTypeMatch(buffer_var, t.element_of())) { os << '('; - auto it = alloc_storage_scope_.find(buffer); + auto it = alloc_storage_scope_.find(buffer_var); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, os); } PrintType(t.element_of(), os); os << "*)"; } - os << GetVarID(buffer) << " + "; + os << GetVarID(buffer_var) << " + "; PrintExpr(base, os); } -std::string CodeGenOpenCL::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) { +std::string CodeGenOpenCL::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) { std::ostringstream os; os << "vload" << t.lanes() << "(0, "; PrintVecAddr(buffer, t, base, os); @@ -282,7 +283,7 @@ std::string CodeGenOpenCL::GetVecLoad(DataType t, const VarNode* buffer, PrimExp return os.str(); } -void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, +void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, const std::string& value) { this->PrintIndent(); stream << "vstore" << t.lanes() << "(" << value << ", 0, "; diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index c72875e8561f..3ae11c69a34e 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -47,11 +47,11 @@ class CodeGenOpenCL final : public CodeGenC { void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintType(const Type& type, std::ostream& os) final; // NOLINT(*) - std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) final; - void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, + std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) final; + void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) // the address of load/store - void PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, + void PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base, std::ostream& os); // NOLINT(*) void PrintRestrict(const Var& v, std::ostream& os) final; // NOLINT(*) std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 1227ed8e88d7..220722b5644f 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -275,20 +275,14 @@ def test_cuda_shuffle(): def MyVectorize(): def vectorizer(op): if op.kind == tvm.tir.ForKind.VECTORIZED: - four = tvm.tir.const(4, "int32") - idx = tvm.tir.Ramp(thrx.var * four, tvm.tir.const(1, "int32"), 4) - all_ones = tvm.tir.const(1, "int32x4") + idx = tvm.tir.Ramp(4 * thrx.var, 1, 4) store = op.body value = store.value new_a = tvm.tir.BufferLoad(value.a.buffer, [idx]) bs, ids = [], [] for i in range(4): - bs.append( - tvm.tir.BufferLoad( - value.b.buffer, [thrx.var * four + tvm.tir.const(i, "int32")] - ) - ) - ids.append(tvm.tir.const(3 - i, "int32")) + bs.append(tvm.tir.BufferLoad(value.b.buffer, [4 * thrx.var + i])) + ids.append(3 - i) new_b = tvm.tir.Shuffle(bs, ids) return tvm.tir.BufferStore(store.buffer, new_a + new_b, [idx]) return None @@ -808,21 +802,21 @@ def vcf_check_common(s, args): inside_broadcast = [False] # Possible patterns: - # Reduce init: Store[Ramp] = Broadcast(0) - # Shared memory copy: Store[Ramp] = Load[Ramp] - # Compute: Store[Ramp] = Load[Ramp] ... Broadcast[Load] + # Reduce init: BufferStore[Ramp] = Broadcast(0) + # Shared memory copy: BufferStore[Ramp] = BufferLoad[Ramp] + # Compute: BufferStore[Ramp] = BufferLoad[Ramp] ... Broadcast[Load] def pre_visit(stmt): if isinstance(stmt, tvm.tir.Broadcast): inside_broadcast[0] = True # Check Broadcast[Imm numbers] or Broadcast[Load] patterns - assert isinstance(stmt.value, (tvm.tir.IntImm, tvm.tir.FloatImm, tvm.tir.Load)) + assert isinstance(stmt.value, (tvm.tir.IntImm, tvm.tir.FloatImm, tvm.tir.BufferLoad)) if isinstance(stmt, tvm.tir.Store): # Check Store[Ramp] pattern assert isinstance(stmt.index, tvm.tir.Ramp) - if isinstance(stmt, tvm.tir.Load): - # Check Broadcast[Load] or Load[Ramp] patterns - assert inside_broadcast[0] or isinstance(stmt.index, tvm.tir.Ramp) + if isinstance(stmt, tvm.tir.BufferLoad): + # Check Broadcast[BufferLoad] or BufferLoad[Ramp] patterns + assert inside_broadcast[0] or isinstance(stmt.indices[-1], tvm.tir.Ramp) # Skip the rest return stmt return None From 53c0362e57e2f4bf9ca5a7afdf11f4d1055ceb6c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 11 Feb 2022 14:40:16 -0600 Subject: [PATCH 106/177] Bugfix, incorrect LCA for buffer access in root scope. This had been present before the BufferLoad/BufferStore changes, but hadn't triggered on tests using Load/Store nodes. --- src/tir/analysis/buffer_access_lca_detector.cc | 15 +++++++++++++-- src/tir/transforms/compact_buffer_region.cc | 3 ++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index c004c86fe77a..b71e6b27f486 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -43,6 +43,13 @@ class LCADetector : public StmtExprVisitor { detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get()); } + // The root node must be explicitly present in the list of + // ancestor_scopes_. We cannot use nullptr to represent the root + // node, as that is also used to represent a scope that hasn't + // been observed before. + ScopeInfo root(nullptr, nullptr, 0); + detector.ancestor_scopes_.push_back(&root); + detector(func->body); // Prepare the return Map> buffer_lca; @@ -135,6 +142,7 @@ class LCADetector : public StmtExprVisitor { } void UpdateBufferLCA(const BufferNode* buffer) { + buffer_var_map_.emplace(buffer->data.get(), buffer); if (match_buffers_.find(buffer) == match_buffers_.end()) { // Ingore buffer created by block match_buffer const ScopeInfo*& lca = buffer_lca_[buffer]; @@ -167,8 +175,11 @@ class LCADetector : public StmtExprVisitor { return lhs; } - /*! \brief The ancestor scope stacks info (Block and For), initialized with Null. */ - std::vector ancestor_scopes_ = {nullptr}; + /*! \brief The ancestor scope stacks info (Block and For). The + * first element is initialized in LCADetector::Detect to represent + * the root scope. + */ + std::vector ancestor_scopes_ = {}; /*! \brief The map from Buffer to its LCA ForNode/BlockNode. */ std::unordered_map buffer_lca_ = {}; /*! \brief The map from Buffer data to the Buffer. */ diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 2970f81cccca..6a317397d6ea 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -215,7 +215,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { continue; } auto dom_it = dom_map_.find(v); - ICHECK(dom_it != dom_map_.end()); + ICHECK(dom_it != dom_map_.end()) + << "Could not find domain for loop variable " << v->name_hint; non_relaxed[i] = dom_it->second; dom_map_.erase(dom_it); } From c636e9b97bd01874ecb0cf966dbefb2e608c3433 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 11 Feb 2022 16:17:08 -0600 Subject: [PATCH 107/177] Added docstrings for TransformNode member variables. --- include/tvm/te/schedule.h | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index deafb3f929ee..8e637b43b52e 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -840,9 +840,34 @@ class Singleton : public IterVarRelation { */ class TransformNode : public IterVarRelationNode { public: + /*! \brief The loop variables that were replaced by the transformation. + * + * Prior to applying a layout transformation, these represent the + * loops to iterate over a tensor as it is being computed, following + * a row-major traversal of the tensor's original shape in the + * compute definition. + */ Array original_variables; + + /*! \brief The variables generated by the transformation. + * + * After to applying a layout transformation, these represent the + * loops to iterate over a tensor as it is being computed, following + * a row-major traversal of the transformed shape of the tensor. + */ Array transformed_variables; + + /*! \brief Map from the original variables to the transformed variables. + * + * Used to determine iterator ranges over the transformed variables. + */ IndexMap forward_transformation; + + /*! \brief Map from transformed variables to the original variables + * + * Used to rewrite expressions containing the original loop iterators + * in terms of the transformed loop iterators. + */ IndexMap inverse_transformation; void VisitAttrs(AttrVisitor* v) { From e40414fc81d99c235f857ae9c741a4f25d072f79 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 11 Feb 2022 16:38:00 -0600 Subject: [PATCH 108/177] Added TODO for future removal of preflattened_buffer_map. --- include/tvm/tir/function.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 08691a889e13..049d351714a9 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -99,6 +99,13 @@ class PrimFuncNode : public BaseFuncNode { * in `preflattened_buffer_map` is assumed to be the same before * and after flattening (e.g. a 1-d tensor that is backed by 1-d * flat memory). + * + * TODO(Lunderberg): Remove preflattened_buffer_map, and instead + * declare each flattened buffer as aliasing the original tensor + * shape. This should include improving the StmtExprMutator to + * provide easier interactions with Buffer objects, so that the + * bookkeeping of relationships between buffers doesn't need to be + * repeated across several transforms. */ Map preflattened_buffer_map; From 27552d64e0db6df95cf0574f03c353d63215b544 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 14 Feb 2022 13:01:23 -0600 Subject: [PATCH 109/177] Fixup, transform layout + cache write tests. The correct sequence is to first apply any caching as needed, then to apply layout transformations, and finally to apply thread binds for the computation step. --- .../python/unittest/test_transform_layout.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py index c70be6f782eb..2b29ac498ed2 100755 --- a/tests/python/unittest/test_transform_layout.py +++ b/tests/python/unittest/test_transform_layout.py @@ -425,22 +425,36 @@ class TestTransformCache: cache_B = tvm.testing.parameter(by_dict={"cacheB": True, "": False}) @tvm.testing.fixture - def schedule_args(self, A_size, transform_A, transform_B, cache_A, cache_B, dtype): + def schedule_args(self, target, A_size, transform_A, transform_B, cache_A, cache_B, dtype): A = te.placeholder(shape=[A_size], dtype=dtype, name="A") B = te.compute(A.shape, lambda i: A[i], name="B") s = te.create_schedule(B.op) - if transform_A: - A_axis = s[A].transform_layout(lambda i: [i // 4, i % 4]) - - if transform_B: - B_axis = s[B].transform_layout(lambda i: [i // 4, i % 4]) + requires_thread_bind = "gpu" in tvm.target.Target(target).keys + thread_x = te.thread_axis("threadIdx.x") + thread_y = te.thread_axis("threadIdx.y") + thread_z = te.thread_axis("threadIdx.z") if cache_A: AA = s.cache_read(A, "shared", [B]) + if requires_thread_bind: + s[AA].bind(AA.op.axis[0], thread_x) if cache_B: BB = s.cache_write(B, "shared") + if requires_thread_bind: + s[BB].bind(BB.op.axis[0], thread_y) + + if transform_A: + A_axis = s[A].transform_layout(lambda i: [i // 4, i % 4]) + + if transform_B: + B_axis = s[B].transform_layout(lambda i: [i // 4, i % 4]) + else: + B_axis = B.op.axis + + if requires_thread_bind: + s[B].bind(B_axis[0], thread_z) return [s, [A, B]] From b85b4eedd435c421729bcd6bd81264d98f2740c0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 14 Feb 2022 16:20:04 -0600 Subject: [PATCH 110/177] Bugfix, correct element type for scalarized access. --- src/target/source/codegen_c.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index ce25542b7c17..fb94ca287c51 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -720,7 +720,7 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) { std::string vid = GetVarID(buffer_var.get()); for (int i = 0; i < value_dtype.lanes(); ++i) { this->PrintIndent(); - DataType elem_type = value_dtype; + DataType elem_type = value_dtype.element_of(); if (!HandleTypeMatch(buffer_var.get(), elem_type)) { stream << "(("; if (buffer_var.get()->dtype.is_handle()) { From a8b5fa37e86b882dbca8874a291ec9f6c9e217fd Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 15 Feb 2022 12:41:00 -0600 Subject: [PATCH 111/177] Bugfix, cuda buffer indexing when declared as different type. --- src/target/source/codegen_c.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index fb94ca287c51..0ae514439e17 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -193,12 +193,14 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp std::string index_str = PrintExpr(index); if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { - std::stringstream temp; - temp << "(" << index_str << ") / " << (32 / t.bits()); - index_str = temp.str(); - } - - if (t == buffer_element_dtype) { + // This is a special case, because CodegenCUDA::PrintType() + // returns "int" for bool and for 4-bit integers. Therefore, we + // need to do the pointer arithmetic in the output's datatype, + // rather than the buffer's element type. + os << "*(" + << "(" << ptr_cast(t) << vid << ")" + << " + " << index_str << " / " << t.lanes() << ")"; + } else if (t == buffer_element_dtype) { os << buffer_str << "[" << index_str << "]"; } else { os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")"; From e2342dcaa5b05a22722fa3516b36c9a907f77274 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 15 Feb 2022 13:16:49 -0600 Subject: [PATCH 112/177] Cuda codegen, update reference. --- tests/python/unittest/test_target_codegen_cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 220722b5644f..483fdf581172 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -1031,7 +1031,7 @@ def build(A, C, N, C_N): N, C_N, A, C = get_compute_aligned() a_data, c, kernel_source = build(A, C, N, C_N) # (uint1*)(A + (2)) is a valid vector load - assert "A + (2)" in kernel_source + assert "A + 2" in kernel_source expected = a_data[2 : C_N + 2] assert np.allclose(c, expected), f"expected={expected}\nactual={c}" From 70d9d3c673ac6ec88621ec3d741f3bd5f9ee7b61 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 15 Feb 2022 15:08:01 -0600 Subject: [PATCH 113/177] Bugfix, lower allreduce Loads of the output of the reduction should be replaced for all buffers sharing a buffer pointer, not just for the buffer object itself. --- src/tir/transforms/lower_thread_allreduce.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 3f344c751cdf..ce4dbcac557d 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -120,7 +120,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { { - auto it = load_remap_.find(op->buffer.get()); + auto it = load_remap_.find(op->buffer->data.get()); if (it != load_remap_.end()) { for (const auto& index : op->indices) { ICHECK(is_zero(index)); @@ -398,12 +398,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Update existing allocations. for (size_t i = 0; i < size; ++i) { - ICHECK(!load_remap_.count(buffers[i].get())); + ICHECK(!load_remap_.count(buffers[i]->data.get())); PrimExpr pred = const_true(types[i].lanes()); Buffer buf = shared_bufs[i]; PrimExpr val = BufferLoad(buf, zero_indices); ICHECK_EQ(val->dtype, types[i]); - load_remap_[buffers[i].get()] = val; + load_remap_[buffers[i]->data.get()] = val; store_remap_[buffers[i].get()] = buf; Array extents{PrimExpr(1)}; auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0)); @@ -442,12 +442,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index, reduce_extent, threadx_extent)); for (size_t idx = 0; idx < size; ++idx) { - ICHECK(!load_remap_.count(buffers[idx].get())); + ICHECK(!load_remap_.count(buffers[idx]->data.get())); PrimExpr pred = const_true(types[idx].lanes()); BufferLoad load(shared_bufs[idx], {BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)}); ICHECK_EQ(load->dtype, types[idx]); - load_remap_[buffers[idx].get()] = load; + load_remap_[buffers[idx]->data.get()] = load; alloc_remap_[buffers[idx]->data.get()] = Allocate(shared_bufs[idx]->data, types[idx], {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0)); @@ -669,7 +669,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::vector thread_extents_; std::vector reduce_combiner_; // The load remap - std::unordered_map load_remap_; + std::unordered_map load_remap_; // The store remap std::unordered_map store_remap_; // Allocate remap From 2e09604d6557f417d871fdc775e49174f7acd008 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 15 Feb 2022 16:19:56 -0600 Subject: [PATCH 114/177] Removed obsolete comment. --- include/tvm/tir/buffer.h | 6 ------ 1 file changed, 6 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 69d6777d87f1..aef82ae368d0 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -115,12 +115,6 @@ class BufferNode : public Object { bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const { // Use DefEqual as buffer can define variables in its semantics, // skip name as name is not important. - - // The pre-flattened information is only used for type-checking, - // and doesn't represent a different computation. - // - // TODO(Lunderberg): Move the pre-flattened buffer information - // into the PrimFunc's buffer_map. return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && equal.DefEqual(axis_separators, other->axis_separators) && From 2029ced9eda0cef5fd5902333e7780fbeeb113a2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 15 Feb 2022 16:20:12 -0600 Subject: [PATCH 115/177] Changed PrimFunc constructor preflattened_buffer_map to Optional --- include/tvm/tir/function.h | 9 +++++---- src/tir/ir/function.cc | 7 ++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 049d351714a9..2739eb41ef58 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -179,10 +179,11 @@ class PrimFunc : public BaseFunc { * * \param span The location of this object in the source code. */ - TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), - Map buffer_map = Map(), - Map preflattened_buffer_map = Map(), - DictAttrs attrs = NullValue(), Span span = Span()); + TVM_DLL PrimFunc( + Array params, Stmt body, Type ret_type = VoidType(), + Map buffer_map = Map(), + Optional> preflattened_buffer_map = Optional>(), + DictAttrs attrs = NullValue(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 058f350059cd..b650b4f5aa09 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -37,8 +37,9 @@ LinkedParam::LinkedParam(int64_t id, ::tvm::runtime::NDArray param) { // Get the function type of a PrimFunc PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, - Map buffer_map, Map preflattened_buffer_map, - DictAttrs attrs, Span span) { + Map buffer_map, + Optional> preflattened_buffer_map, DictAttrs attrs, + Span span) { // Assume void-return type for now // TODO(tvm-team) consider type deduction from body. if (!ret_type.defined()) { @@ -49,7 +50,7 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, n->body = std::move(body); n->ret_type = std::move(ret_type); n->buffer_map = std::move(buffer_map); - n->preflattened_buffer_map = std::move(preflattened_buffer_map); + n->preflattened_buffer_map = preflattened_buffer_map.value_or(Map()); n->attrs = std::move(attrs); n->checked_type_ = n->func_type_annotation(); n->span = std::move(span); From c8f9015cf9b33066cf8a87662f56a385bc33127e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 15 Feb 2022 16:25:46 -0600 Subject: [PATCH 116/177] Removed flatten_buffer argument from T.match_buffer. --- python/tvm/script/tir/special_stmt.py | 8 +------- tests/python/unittest/test_lower_build.py | 24 +++++++++++++++-------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 05497b975127..d9c6dbda47b2 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -132,7 +132,6 @@ def match_buffer( align=-1, offset_factor=0, buffer_type="default", - flatten_buffer=False, span=None, ): if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: @@ -166,12 +165,7 @@ def match_buffer( self.context.report_error( "Can not bind non-input param to buffer", self.node.rhs.params[0].span ) - if flatten_buffer: - self.context.func_preflattened_buffer_map[param] = buffer - buffer = buffer.get_flattened_buffer() - self.context.func_buffer_map[param] = buffer - else: - self.context.func_buffer_map[param] = buffer + self.context.func_buffer_map[param] = buffer elif isinstance(param, BufferSlice): buffer_region = buffer_slice_to_region(param) self.context.current_block_scope().match_buffers.append( diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index 326554e90e5a..4b4eb28c4fca 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -53,12 +53,16 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: @tvm.script.ir_module class LoweredModule: @T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle) -> None: + def main( + A: T.Buffer[(16384,), "float32"], + B: T.Buffer[(16384,), "float32"], + C: T.Buffer[(16384,), "float32"], + ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) - A = T.match_buffer(a, [128, 128], flatten_buffer=True) - B = T.match_buffer(b, [128, 128], flatten_buffer=True) - C = T.match_buffer(c, [128, 128], flatten_buffer=True) + T.preflattened_buffer(A, [128, 128], data=A.data) + T.preflattened_buffer(B, [128, 128], data=B.data) + T.preflattened_buffer(C, [128, 128], data=C.data) # body for x, y in T.grid(128, 128): C[x * 128 + y] = 0.0 @@ -69,12 +73,16 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: @tvm.script.ir_module class LoweredTIRModule: @T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle) -> None: + def main( + A: T.Buffer[(16384,), "float32"], + B: T.Buffer[(16384,), "float32"], + C: T.Buffer[(16384,), "float32"], + ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [128, 128], flatten_buffer=True) - B = T.match_buffer(b, [128, 128], flatten_buffer=True) - C = T.match_buffer(c, [128, 128], flatten_buffer=True) + T.preflattened_buffer(A, [128, 128], data=A.data) + T.preflattened_buffer(B, [128, 128], data=B.data) + T.preflattened_buffer(C, [128, 128], data=C.data) # body for x, y in T.grid(128, 128): C[x * 128 + y] = 0.0 From 8f971592c717a3d7fe159dff9368b928ce6282be Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 15 Feb 2022 16:30:37 -0600 Subject: [PATCH 117/177] Correct call to VarUseDefAnalysis::VisitBuffer --- src/tir/transforms/split_host_device.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 4f9530b93fda..a0e117e3157e 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -112,7 +112,7 @@ class VarUseDefAnalysis : public StmtExprMutator { } Stmt VisitStmt_(const BufferStoreNode* op) final { - this->HandleUse(op->buffer->data); + VisitBuffer(op->buffer); return StmtExprMutator::VisitStmt_(op); } @@ -165,7 +165,7 @@ class VarUseDefAnalysis : public StmtExprMutator { } PrimExpr VisitExpr_(const BufferLoadNode* op) final { - this->HandleUse(op->buffer->data); + VisitBuffer(op->buffer); return StmtExprMutator::VisitExpr_(op); } From e3e3d896e70141d72cf9a8bb6b8553be13d3e15e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 15 Feb 2022 16:38:10 -0600 Subject: [PATCH 118/177] Reverted unintentional testing change, lanes=2. --- .../unittest/test_tir_transform_instrument_bound_checkers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py index 279faa54d830..9f61b5a3920a 100644 --- a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py +++ b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py @@ -94,7 +94,7 @@ def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b): @tvm.testing.requires_llvm def test_in_bounds_vectorize_llvm(): n = 512 - lanes = 1 + lanes = 2 A = te.placeholder((n,), name="A", dtype="float32x%d" % lanes) B = te.compute((n,), lambda i: A[i], name="B") C = te.compute((n,), lambda i: B[i] + tvm.tir.const(1, A.dtype), name="C") From d8b88a90201d2887c95a7add60f3e5776931224b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 16 Feb 2022 11:13:30 -0600 Subject: [PATCH 119/177] Updated lower_cross_thread_reduction to use buffer in allreduce --- src/tir/transforms/lower_cross_thread_reduction.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 4df38ff543b5..df8bf69e7468 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -314,7 +314,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optionaldata); + parameters.push_back(BufferLoad(ct_buffer, {0})); // next arguments: all the reduction threads for (const ForNode* reduction_loop : reduction_loops) { if (reduction_loop->thread_binding.defined()) { From fa941c973466146c9036e1d5b684bc6b06ac54d1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 16 Feb 2022 11:23:46 -0600 Subject: [PATCH 120/177] Updated transform_layout test to disable CSE --- tests/python/unittest/test_transform_layout.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py index 2b29ac498ed2..5cac01dd7f7c 100755 --- a/tests/python/unittest/test_transform_layout.py +++ b/tests/python/unittest/test_transform_layout.py @@ -258,7 +258,12 @@ def test_2d_physical(self, dtype, transform_A, transform_B): if transform_B: s[B].transform_layout(lambda i, j, k: [i, j, te.AXIS_SEPARATOR, k]) - mod = tvm.lower(s, [A, B]) + # If the two buffers are accessed with the same indices, CSE + # will replace them with a Let binding. Since this makes it + # harder to test what the transformed indices are, disabling + # the CSE pass for this test. + with tvm.transform.PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]): + mod = tvm.lower(s, [A, B]) i, j, k = self.extract_loop_vars(mod["main"].body) indices_1d = [i * (logical_shape[1] * logical_shape[2]) + j * logical_shape[2] + k] From fb14c5e456659652559592215605e024007b214b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 16 Feb 2022 11:38:22 -0600 Subject: [PATCH 121/177] Updated CSE unit tests to use BufferStore --- .../test_tir_transform_common_subexpr_elim.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index b01a9e652f77..17c0cbdd99c6 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -45,7 +45,7 @@ def test_cse(): 2, tvm.tir.SeqStmt( [ - tvm.tir.Store(buffer.data, z1 + z2, i1), + tvm.tir.BufferStore(buffer, z1 + z2, [i1]), tvm.tir.LetStmt( x, 1, @@ -56,7 +56,7 @@ def test_cse(): a, (x + y) + (z1 + z2), tvm.tir.LetStmt( - b, (x + y) + z3, tvm.tir.Store(buffer.data, a + b, i2) + b, (x + y) + z3, tvm.tir.BufferStore(buffer, a + b, [i2]) ), ), ), @@ -96,7 +96,7 @@ def test_cse(): body = body.body - assert isinstance(body[0], tvm.tir.Store) + assert isinstance(body[0], tvm.tir.BufferStore) assert isinstance(body[1], tvm.tir.LetStmt) body = body[1] @@ -130,7 +130,7 @@ def test_cse(): # Check that the replacement has been done correctly! assert tvm.ir.structural_equal(body.value, cse_var_2 + z3) - assert isinstance(body.body, tvm.tir.Store) + assert isinstance(body.body, tvm.tir.BufferStore) # First specific test for if nodes : Some duplicated computations appear only in one branch (here the Then branch), not in both branches. @@ -160,9 +160,9 @@ def test_cse_ifNode_1(): tvm.tir.IfThenElse( b, tvm.tir.SeqStmt( - [tvm.tir.Store(buffer.data, y + z, i1), tvm.tir.Store(buffer.data, y + z, i2)] + [tvm.tir.BufferStore(buffer, y + z, [i1]), tvm.tir.BufferStore(buffer, y + z, [i2])] ), - tvm.tir.Store(buffer.data, y, i3), + tvm.tir.BufferStore(buffer, y, [i3]), ), ) @@ -217,11 +217,11 @@ def test_cse_ifNode_2(): b, tvm.tir.SeqStmt( [ - tvm.tir.Store(buffer.data, y + z, i1), # (y+z) is present in the Then branch - tvm.tir.Store(buffer.data, y, i2), + tvm.tir.BufferStore(buffer, y + z, [i1]), # (y+z) is present in the Then branch + tvm.tir.BufferStore(buffer, y, [i2]), ] ), - tvm.tir.Store(buffer.data, y + z, i3), # and also present in the Else branch + tvm.tir.BufferStore(buffer, y + z, [i3]), # and also present in the Else branch ), ) @@ -258,9 +258,9 @@ def test_cse_cascade(): # Mem[i3] = x+y body = tvm.tir.SeqStmt( [ - tvm.tir.Store(buffer.data, (x + y) + z, i1), - tvm.tir.Store(buffer.data, (x + y) + z, i2), - tvm.tir.Store(buffer.data, (x + y), i3), + tvm.tir.BufferStore(buffer, (x + y) + z, [i1]), + tvm.tir.BufferStore(buffer, (x + y) + z, [i2]), + tvm.tir.BufferStore(buffer, (x + y), [i3]), ] ) @@ -292,9 +292,9 @@ def test_cse_cascade(): body = body.body assert isinstance(body, tvm.tir.SeqStmt) - assert isinstance(body[0], tvm.tir.Store) - assert isinstance(body[1], tvm.tir.Store) - assert isinstance(body[2], tvm.tir.Store) + assert isinstance(body[0], tvm.tir.BufferStore) + assert isinstance(body[1], tvm.tir.BufferStore) + assert isinstance(body[2], tvm.tir.BufferStore) store1 = body[0] store2 = body[1] From 120bb5b3f3da4bf9b2ffc0c13f77cfac450108c2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 16 Feb 2022 13:59:48 -0600 Subject: [PATCH 122/177] Replaced Store/Load for vta.transform and unit tests. --- vta/python/vta/transform.py | 97 +++++++++++++++++++++++++------------ 1 file changed, 65 insertions(+), 32 deletions(-) diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 383841f19e34..1e8247c6e135 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -156,15 +156,45 @@ def CPUAccessRewrite(): """ def _ftransform(f, mod, ctx): - rw_info = {} env = get_env() + var_remap = {} + buf_remap = {} + + def find_var_remap(old_var): + if old_var in var_remap: + return var_remap[old_var] + + new_var = tvm.tir.Var(old_var.name + "_ptr", dtype=old_var.type_annotation) + var_remap[old_var] = new_var + return new_var + + def find_buf_remap(old_buf): + if old_buf in buf_remap: + return buf_remap[old_buf] + + new_var = find_var_remap(old_buf.data) + new_buf = tvm.tir.decl_buffer( + shape=old_buf.shape, + dtype=old_buf.dtype, + data=new_var, + strides=old_buf.strides, + elem_offset=old_buf.elem_offset, + scope=old_buf.scope, + data_alignment=old_buf.data_alignment, + offset_factor=old_buf.offset_factor, + buffer_type="auto_broadcast" if (old_buf.buffer_type == 2) else "", + axis_separators=old_buf.axis_separators, + ) + buf_remap[old_buf] = new_buf + return new_buf + def _post_order(op): if isinstance(op, tvm.tir.Allocate): buffer_var = op.buffer_var - if not buffer_var in rw_info: + if buffer_var not in var_remap: return None - new_var = rw_info[buffer_var] + new_var = var_remap[buffer_var] let_stmt = tvm.tir.LetStmt( new_var, tvm.tir.call_extern( @@ -173,33 +203,31 @@ def _post_order(op): op.body, ) alloc = tvm.tir.Allocate(buffer_var, op.dtype, op.extents, op.condition, let_stmt) - del rw_info[buffer_var] + del var_remap[buffer_var] + bufs_to_delete = [ + old_buf for old_buf in buf_remap if old_buf.data.same_as(buffer_var) + ] + for buf in bufs_to_delete: + del buf_remap[buf] return alloc - if isinstance(op, tvm.tir.Load): - buffer_var = op.buffer_var - if not buffer_var in rw_info: - rw_info[buffer_var] = te.var(buffer_var.name + "_ptr", "handle") - new_var = rw_info[buffer_var] - return tvm.tir.Load(op.dtype, new_var, op.index) - if isinstance(op, tvm.tir.Store): - buffer_var = op.buffer_var - if not buffer_var in rw_info: - rw_info[buffer_var] = te.var(buffer_var.name + "_ptr", "handle") - new_var = rw_info[buffer_var] - return tvm.tir.Store(new_var, op.value, op.index) + + if isinstance(op, tvm.tir.BufferLoad): + return tvm.tir.BufferLoad(find_buf_remap(op.buffer), op.indices) + + if isinstance(op, tvm.tir.BufferStore): + return tvm.tir.BufferStore(find_buf_remap(op.buffer), op.value, op.indices) + raise RuntimeError("not reached") stmt_in = f.body stmt = tvm.tir.stmt_functor.ir_transform( - stmt_in, None, _post_order, ["tir.Allocate", "tir.Load", "tir.Store"] + stmt_in, None, _post_order, ["tir.Allocate", "tir.BufferLoad", "tir.BufferStore"] ) - for buffer_var, new_var in rw_info.items(): + for old_var, new_var in var_remap.items(): stmt = tvm.tir.LetStmt( new_var, - tvm.tir.call_extern( - "handle", "VTABufferCPUPtr", env.dev.command_handle, buffer_var - ), + tvm.tir.call_extern("handle", "VTABufferCPUPtr", env.dev.command_handle, old_var), stmt, ) return f.with_body(stmt) @@ -919,8 +947,8 @@ def _flatten_loop(src_coeff, dst_coeff, extents): loop_body = loop_body.body nest_size += 1 # Get the src/dst arguments - dst_var = loop_body.buffer_var - dst_idx = loop_body.index + dst_var = loop_body.buffer.data + dst_idx = loop_body.indices[0] # Derive loop variables and extents tmp_body = stmt.body indices = [] @@ -963,7 +991,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents): raise RuntimeError( "Function call not recognized %s" % (loop_body.value.name) ) - elif isinstance(loop_body.value, tvm.tir.Load): + elif isinstance(loop_body.value, tvm.tir.BufferLoad): alu_opcode = env.dev.ALU_OPCODE_SHR lhs = loop_body.value rhs = tvm.tir.const(0, "int32") @@ -979,20 +1007,20 @@ def _flatten_loop(src_coeff, dst_coeff, extents): use_imm = False imm_val = None if isinstance(rhs, tvm.tir.IntImm): - assert lhs.buffer_var.same_as(dst_var) - src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) + assert lhs.buffer.data.same_as(dst_var) + src_coeff = tvm.arith.detect_linear_equation(lhs.indices[0], indices) use_imm = True imm_val = rhs if isinstance(lhs, tvm.tir.IntImm): - assert rhs.buffer_var.same_as(dst_var) - src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) + assert rhs.buffer.data.same_as(dst_var) + src_coeff = tvm.arith.detect_linear_equation(rhs.indices[0], indices) use_imm = True imm_val = lhs if imm_val is None: imm_val = 0 - assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var) - src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) - src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) + assert lhs.buffer.data.same_as(dst_var) and rhs.buffer.data.same_as(dst_var) + src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.indices[0], indices) + src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.indices[0], indices) # Determine which side has the same coefficients lhs_equal = True rhs_equal = True @@ -1058,7 +1086,12 @@ def _flatten_loop(src_coeff, dst_coeff, extents): for idx, extent in enumerate(extents): irb.emit( tvm.tir.call_extern( - "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx], 0 + "int32", + "VTAUopLoopBegin", + extent, + dst_coeff[idx], + src_coeff[idx], + 0, ) ) use_imm = int(use_imm) From e4c169d4774c4e59f8d130c4c4a6a6b39f38c42f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 17 Feb 2022 08:56:07 -0600 Subject: [PATCH 123/177] Updated unit tests for lower_cross_thread_reduction. --- ..._transform_lower_cross_thread_reduction.py | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index 5b3d7283f14f..e2e688aac1bf 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -25,7 +25,14 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(original) mod = tvm.tir.transform.LowerCrossThreadReduction()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed, True) + try: + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + except ValueError: + with open("temp_expected.txt", "w") as f: + f.write(transformed.script()) + with open("temp_observed.txt", "w") as f: + f.write(mod["main"].script()) + raise def _check_fail(original): @@ -82,7 +89,7 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None: T.uint32(1), normal_reduce_temp0[0], True, - reduce_temp0.data, + reduce_temp0[0], ki, dtype="handle", ) @@ -127,7 +134,7 @@ def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None: ) T.evaluate( T.tvm_thread_allreduce( - T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, dtype="handle" + T.uint32(1), A[vi, vk], True, reduce_temp0[0], k, dtype="handle" ) ) with T.block("B_write_back"): @@ -174,7 +181,7 @@ def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None: ) T.evaluate( T.tvm_thread_allreduce( - T.uint32(1), A[vi, vk], True, reduce_temp0.data, ko, ki, dtype="handle" + T.uint32(1), A[vi, vk], True, reduce_temp0[0], ko, ki, dtype="handle" ) ) with T.block("B_write_back"): @@ -253,7 +260,7 @@ def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> No T.uint32(1), normal_reduce_temp0[0], True, - reduce_temp0.data, + reduce_temp0[0], k0o, dtype="handle", ) @@ -315,7 +322,7 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None: T.uint32(1), normal_reduce_temp0[0], True, - reduce_temp0.data, + reduce_temp0[0], ki, dtype="handle", ) @@ -418,7 +425,7 @@ def lowered_single_reduction_loop_with_block_predicate( T.uint32(1), in_thread_0[0], True, - cross_thread_0.data, + cross_thread_0[0], ax1_1, dtype="handle", ) @@ -456,7 +463,7 @@ def lowered_single_reduction_loop_with_block_predicate( T.uint32(1), in_thread_1[0], True, - cross_thread_1.data, + cross_thread_1[0], ax1_1, dtype="handle", ) @@ -516,7 +523,7 @@ def lowered_reducer_max(a: T.handle, b: T.handle) -> None: ) T.evaluate( T.tvm_thread_allreduce( - T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, dtype="handle" + T.uint32(1), A[vi, vk], True, reduce_temp0[0], k, dtype="handle" ) ) with T.block("B_write_back"): @@ -556,9 +563,7 @@ def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None: T.reinterpret(T.uint64(0), dtype="handle"), ) T.evaluate( - T.tvm_thread_allreduce( - T.uint32(1), A[vk], True, reduce_temp0.data, k, dtype="handle" - ) + T.tvm_thread_allreduce(T.uint32(1), A[vk], True, reduce_temp0[0], k, dtype="handle") ) with T.block("B_write_back"): T.reads([reduce_temp0[0]]) @@ -746,7 +751,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: T.uint32(1), normal_reduce_temp0[0], True, - reduce_temp0.data, + reduce_temp0[0], ax0_1, dtype="handle", ) @@ -789,7 +794,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: T.uint32(1), normal_reduce_temp1[0], True, - reduce_temp1.data, + reduce_temp1[0], ax0_1, dtype="handle", ) From 64258824fa80067f0dd6b4bd2284fce596e9c8ed Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 17 Feb 2022 15:13:57 -0600 Subject: [PATCH 124/177] Updated arange to use scalar tensors. The start/stop/step tensors are declared as 0-d scalar tensors, but were accessed as 1-d tensors. --- python/tvm/relay/op/_transform.py | 6 +++--- src/relay/op/tensor/transform.cc | 6 +++++- src/tir/ir/expr.cc | 5 +++++ src/tir/ir/stmt.cc | 5 +++++ 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 25ecf0e2e746..6520bd186d09 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -200,10 +200,10 @@ def compute_unique(attrs, inputs, output_type): @script def _arange_shape_func(start, stop, step): out = output_tensor((1,), "int64") - if step[0] < 0: - out[0] = int64(ceil_div((int64(start[0]) - int64(stop[0])), int64(-step[0]))) + if step[()] < 0: + out[0] = int64(ceil_div((int64(start[()]) - int64(stop[()])), int64(-step[()]))) else: - out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0]))) + out[0] = int64(ceil_div((int64(stop[()]) - int64(start[()])), int64(step[()]))) return out diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b5316f2b7bca..3f7da4e954a4 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1571,11 +1571,15 @@ inline te::Tensor DynamicArange(const te::Tensor& start, const te::Tensor& stop, const te::Tensor& step, tvm::DataType dtype, std::string name = "T_arange_dynamic", std::string tag = topi::kInjective) { + ICHECK_EQ(start.ndim(), 0); + ICHECK_EQ(stop.ndim(), 0); + ICHECK_EQ(step.ndim(), 0); tvm::PrimExpr num_elem = tvm::tir::Var("num_elem"); return te::compute( {num_elem}, [&](const Array& indices) { - return tvm::cast(dtype, start[0] + step[0] * indices[0]); + Array empty_indices; + return tvm::cast(dtype, start(empty_indices) + step(empty_indices) * indices[0]); }, name, tag); } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 8f904d5bd8f7..04993603a3dc 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -1070,6 +1070,11 @@ void BufferLoadNode::LegalizeDtype() { } BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { + ICHECK_EQ(buffer->shape.size(), indices.size()) + << "Buffer " << buffer->name << " is " << buffer->shape.size() + << "-dimensional, cannot be indexed with the " << indices.size() + << "-dimensional indices provided."; + ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->indices = std::move(indices); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index b6eec727e2c2..5a98bfec6ffd 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -597,6 +597,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // BufferStore BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, Span span) { + ICHECK_EQ(buffer->shape.size(), indices.size()) + << "Buffer " << buffer->name << " is " << buffer->shape.size() + << "-dimensional, cannot be indexed with the " << indices.size() + << "-dimensional indices provided."; + ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->value = std::move(value); From 4d020483ece0e34239b738e0c9018cfa5024e44d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 18 Feb 2022 09:39:49 -0600 Subject: [PATCH 125/177] Fix breakage in ethosu constant encoding. Buffers generated by "ethosu_copy" should have their buffer objects rewritten, but shouldn't have their size updated in ethosu-specific Call nodes. --- python/tvm/relay/backend/contrib/ethosu/tir/passes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index a4c873d19d75..792072a05b63 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -315,6 +315,7 @@ def EncodeConstants(const_dict): new_const_dict = {} buffer_to_const = {} pointer_to_buffer = {} + encoded_buffers = set() rewrite_buffer = {} rewrite_pointer = {} accel_config = vela_api.get_accelerator_config() @@ -348,6 +349,7 @@ def _new_buffer(old_buffer, new_value): rewrite_buffer[old_buffer] = new_buffer rewrite_pointer[old_buffer.data] = new_buffer.data + encoded_buffers.add(new_buffer) def _visit_encode_pre(stmt): if isinstance(stmt, tvm.tir.Call): @@ -416,7 +418,7 @@ def _visit_rewrite(stmt): if old_buffer.data in pointer_to_buffer: new_buffer = pointer_to_buffer[old_buffer.data] # Only rewrite the arguments of buffers that have been encoded - if new_buffer in new_buffers: + if new_buffer in encoded_buffers: new_arg = np.prod(list(new_buffer.shape)) new_args.append(new_arg) continue From 8bf65736cb73f96197ca64f2806f5ecb1efdd7e2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 18 Feb 2022 09:53:15 -0600 Subject: [PATCH 126/177] Fix breakage in ethosu call argument checks. Need to pull out indices from BufferLoad holders, not Load. --- tests/python/contrib/test_ethosu/infra.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index d5bd28039feb..42638eaea42a 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -565,8 +565,8 @@ def get_binary_elementwise_args(call, include_buffers=False): for i, arg in enumerate(args): if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): binary_elementwise_args.append(arg.value) - elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: - binary_elementwise_args.append(arg.index) + elif isinstance(arg, tvm.tir.expr.BufferLoad) and not include_buffers: + binary_elementwise_args.append(arg.indices[0]) else: binary_elementwise_args.append(arg) From 77841ae62acc06c510a3528607d50ba4f7df9b9d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 18 Feb 2022 10:06:31 -0600 Subject: [PATCH 127/177] Resolve breakage from mismatched shape/index dimensions --- .../test_tir_analysis_calculate_workspace.py | 14 ++-- .../unittest/test_tir_lower_match_buffer.py | 2 +- ...orm_convert_pool_allocations_to_offsets.py | 80 +++++++++---------- tests/python/unittest/test_tir_usmp_utils.py | 12 +-- 4 files changed, 54 insertions(+), 54 deletions(-) diff --git a/tests/python/unittest/test_tir_analysis_calculate_workspace.py b/tests/python/unittest/test_tir_analysis_calculate_workspace.py index e866e996f174..8449782f4589 100644 --- a/tests/python/unittest/test_tir_analysis_calculate_workspace.py +++ b/tests/python/unittest/test_tir_analysis_calculate_workspace.py @@ -26,10 +26,10 @@ def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.handle, placeholder_146: T.handle, T_cast_48: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_13", "tir.noalias": True}) - placeholder_147 = T.match_buffer(placeholder_144, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_148 = T.match_buffer(placeholder_145, [3, 3, 512, 1], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_149 = T.match_buffer(placeholder_146, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_49 = T.match_buffer(T_cast_48, [1*14*14*512], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_147 = T.match_buffer(placeholder_144, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_148 = T.match_buffer(placeholder_145, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_149 = T.match_buffer(placeholder_146, [512], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_49 = T.match_buffer(T_cast_48, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_22 = T.allocate([131072], "int16", "global") DepthwiseConv2d_9 = T.allocate([100352], "int32", "global") @@ -57,9 +57,9 @@ def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.hand def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handle, placeholder_164: T.handle, T_cast_76: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_9", "tir.noalias": True}) - placeholder_165 = T.match_buffer(placeholder_162, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_166 = T.match_buffer(placeholder_163, [3, 3, 512, 1], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_167 = T.match_buffer(placeholder_164, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) + placeholder_165 = T.match_buffer(placeholder_162, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_166 = T.match_buffer(placeholder_163, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_167 = T.match_buffer(placeholder_164, [512], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_77 = T.match_buffer(T_cast_76, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_25 = T.allocate([131072], "int16", "global") diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 623cff6420e6..93b7caf9cdde 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -465,7 +465,7 @@ def fail_match_load(a: T.handle) -> None: T.reads(A[i, j]) T.writes([]) sub_A = T.match_buffer(A[i, j], ()) - T.evaluate(sub_A[0]) + T.evaluate(sub_A[()]) @T.prim_func diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 7d0aad49f7f7..57e63fc1066e 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -73,8 +73,8 @@ class LinearStructure: def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): @@ -85,9 +85,9 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") @@ -107,7 +107,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") @@ -154,7 +154,7 @@ def run_model(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_ @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None: - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8") + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) @@ -171,8 +171,8 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr[T.uint8], slow_memory_3_var: T.Ptr[T.uint8]) -> None: - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8") - placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16") + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) @@ -182,9 +182,9 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr[T.uint8], slow_memory_5_var: T.Ptr[T.uint8]) -> None: - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16") - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16") - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32") + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16") + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8") fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) @@ -250,7 +250,7 @@ class ResnetStructure: def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) - placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") # body @@ -261,9 +261,9 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) - placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") - placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") - placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") + placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") # body PaddedInput_1 = T.allocate([379456], "int16", "global") @@ -282,9 +282,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) - placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") - placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") - placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") + placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") # body PaddedInput_2 = T.allocate([360000], "int16", "global") @@ -304,10 +304,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") - placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") - placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") - placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") + placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") + placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") # body PaddedInput_3 = T.allocate([360000], "int16", "global") @@ -344,9 +344,9 @@ def run_model(input: T.handle, output: T.handle) -> None: def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) - placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") - placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") - placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") + placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") # body PaddedInput = T.allocate([360000], "int16", "global") @@ -368,7 +368,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place class ResnetStructurePlanned: @T.prim_func def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr[T.uint8]) -> None: - placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) @@ -378,10 +378,10 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr[T.uint8]) -> None: - placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") - placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") - placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") - placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") + placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") + placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body @@ -402,9 +402,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr[T.uint8]) -> None: - placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") - placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") - placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") + placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body @@ -425,9 +425,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr[T.uint8]) -> None: - placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") - placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") - placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") + placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body @@ -447,9 +447,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr[T.uint8]) -> None: - placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") - placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") - placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") + placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py index e1541021981a..e6add3a5cfd3 100644 --- a/tests/python/unittest/test_tir_usmp_utils.py +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -31,8 +31,8 @@ class LinearStructure: def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dTpe="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dTpe="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) T_subtract_1 = T.match_buffer(T_subtract, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): @@ -43,9 +43,9 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) - placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") @@ -65,7 +65,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") From c20709c94b948ea6bc67a3d33f1f671cdf91e713 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 18 Feb 2022 11:50:28 -0600 Subject: [PATCH 128/177] Split out encoded parameters from preflattened buffer map. --- python/tvm/relay/backend/contrib/ethosu/tir/passes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 792072a05b63..a8b8dc40e293 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -729,6 +729,7 @@ def CreatePrimFuncWithoutConstants(const_dict): def _ftransform(f, mod, ctx): new_params = list() new_buffer_map = dict() + new_preflattened_buffer_map = dict() for param_idx in const_dict.keys(): # We are using buffer_var to key the constants as # PrimFunc params of constants will be removed. @@ -737,12 +738,13 @@ def _ftransform(f, mod, ctx): if i not in const_dict.keys(): new_params.append(f.params[i]) new_buffer_map[f.params[i]] = f.buffer_map[f.params[i]] + new_preflattened_buffer_map[f.params[i]] = f.preflattened_buffer_map[f.params[i]] return tvm.tir.PrimFunc( new_params, f.body, f.ret_type, new_buffer_map, - f.preflattened_buffer_map, + new_preflattened_buffer_map, f.attrs, f.span, ) From 521556edac11cac0add717be38885543cb393418 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 18 Feb 2022 16:41:20 -0600 Subject: [PATCH 129/177] Updated buffer shape/index dimensions to match in more ethosu tests --- .../contrib/test_ethosu/test_replace_copy.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index afe7aa81c73f..f952eca8fbd5 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -31,11 +31,11 @@ @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([0], "uint8") + buffer_1 = T.buffer_decl([0], "uint8") # body placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) @@ -75,13 +75,13 @@ def _get_func(): @tvm.script.ir_module class WeightStream: @T.prim_func - def main(placeholder_5: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buffer[(1, 16, 16, 16), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(4096,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([0], "uint8") + buffer_1 = T.buffer_decl([0], "uint8") + buffer_2 = T.buffer_decl([0], "uint8") + buffer_3 = T.buffer_decl([0], "uint8") # body placeholder_global = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_d_global = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) From 7a2eb8ee2cd9dd4e219132dd743c178d39ee7d90 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 18 Feb 2022 16:43:26 -0600 Subject: [PATCH 130/177] Fixed lint error --- python/tvm/relay/backend/contrib/ethosu/tir/passes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index a8b8dc40e293..461a6f648722 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -409,7 +409,6 @@ def _visit_rewrite(stmt): # For extern calls, we need to rewrite pairs of arguments corresponding to # base address load and the length of the load. new_args = [stmt.args[0]] - new_buffers = rewrite_buffer.values() for i in range(1, len(stmt.args)): # If the previous argument was a load, the current should be a length if isinstance(stmt.args[i - 1], tvm.tir.BufferLoad): From b08245fed9b5f296723bc660457364e5d82a5a64 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 18 Feb 2022 16:44:20 -0600 Subject: [PATCH 131/177] Removed debug code --- .../test_tir_transform_lower_cross_thread_reduction.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index e2e688aac1bf..2be3bb181150 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -25,14 +25,7 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(original) mod = tvm.tir.transform.LowerCrossThreadReduction()(mod) - try: - tvm.ir.assert_structural_equal(mod["main"], transformed, True) - except ValueError: - with open("temp_expected.txt", "w") as f: - f.write(transformed.script()) - with open("temp_observed.txt", "w") as f: - f.write(mod["main"].script()) - raise + tvm.ir.assert_structural_equal(mod["main"], transformed, True) def _check_fail(original): From f3d17b20cf656b59be771e94af14216ecad285c6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 18 Feb 2022 16:45:24 -0600 Subject: [PATCH 132/177] Moved arith::Analyzer local variable to class member --- src/tir/transforms/vectorize_loop.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 20f67e0e40b0..dc9df9e348c6 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -91,8 +91,6 @@ class VecAllocAccess : public StmtExprMutator { return node; } - arith::Analyzer analyzer; - // Find/make a Buffer object with the correct updated shape. Buffer buf; auto it = buffer_map_.find(node->buffer.get()); @@ -103,7 +101,7 @@ class VecAllocAccess : public StmtExprMutator { // var_lanes_. Typically, this will be a 1-d index into a flat // memory space. Array shape = node->buffer->shape; - shape.Set(shape.size() - 1, analyzer.Simplify(shape[shape.size() - 1] * var_lanes_)); + shape.Set(shape.size() - 1, analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_)); // TODO(Lunderberg): Move this pass to be prior to // StorageFlatten/FlattenBuffer, implement by appending a @@ -118,7 +116,7 @@ class VecAllocAccess : public StmtExprMutator { if (i != strides.size() - 1) { stride *= var_lanes_; } - strides.push_back(analyzer.Simplify(stride)); + strides.push_back(analyzer_.Simplify(stride)); } // Copy everything into the new buffer. @@ -133,7 +131,7 @@ class VecAllocAccess : public StmtExprMutator { // variable. Array indices = node->indices; indices.Set(indices.size() - 1, - analyzer.Simplify(indices[indices.size() - 1] * var_lanes_ + var_)); + analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_)); auto writer = node.CopyOnWrite(); writer->buffer = buf; @@ -149,6 +147,8 @@ class VecAllocAccess : public StmtExprMutator { Var var_; // the lanes. int var_lanes_; + // Analyzer for simplifications + arith::Analyzer analyzer_; }; // We use ExprFunctor directly instead of StmtExprMutator From e8aa9d6549c4c03fde7829ab7c65a91692b0d101 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 22 Feb 2022 13:40:50 -0600 Subject: [PATCH 133/177] Fixed SSA conversion of allocations. Can occur if allocation is inside an unrolled loop. Added unit test to catch this failure mode. --- src/tir/transforms/ir_utils.cc | 95 +++++++++++++------ .../test_tir_transform_unroll_loop.py | 25 +++++ 2 files changed, 90 insertions(+), 30 deletions(-) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 7d8b1963c35b..700c9931bba0 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -101,11 +101,9 @@ class IRConvertSSA final : public StmtExprMutator { const Var& v = op->var; if (defined_.count(v.get())) { PrimExpr value = this->VisitExpr(op->value); - Var new_var(v->name_hint, v.dtype()); - scope_[v.get()].push_back(new_var); + ScopedRedefine redefine(this, v); PrimExpr body = this->VisitExpr(op->body); - scope_[v.get()].pop_back(); - return Let(new_var, value, body); + return Let(redefine.new_var, value, body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitExpr_(op); @@ -124,12 +122,14 @@ class IRConvertSSA final : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); - return VisitBufferAccess(std::move(node)); + auto output = VisitBufferAccess(std::move(node)); + return std::move(output); } Stmt VisitStmt_(const BufferStoreNode* op) final { auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - return VisitBufferAccess(std::move(node)); + auto output = VisitBufferAccess(std::move(node)); + return std::move(output); } template @@ -144,32 +144,46 @@ class IRConvertSSA final : public StmtExprMutator { } Buffer GetRemappedBuffer(Buffer buf) { - auto key = buf.get(); - auto buf_it = buf_remap_.find(key); - if (buf_it != buf_remap_.end()) { - return buf_it->second; - } - + // Determine the buffer var that should be in the updated buffer, + // given the current scope. If no redefines are present, then the + // buffer var is unchanged. + Var new_buffer_var = buf->data; auto var_it = scope_.find(buf->data.get()); if (var_it != scope_.end() && !var_it->second.empty()) { - Var buffer_var = var_it->second.back(); - auto writer = buf.CopyOnWrite(); - writer->data = buffer_var; + new_buffer_var = var_it->second.back(); + } + + // If no mapping is required, return the original buffer. + if (new_buffer_var.same_as(buf->data)) { + return buf; } - buf_remap_[key] = buf; - return buf; + // If the current scope already has a mapping of this buffer, use + // the mapped buffer. + auto key = buf.get(); + std::vector& buffers = buf_remap_[key]; + if (buffers.size() && buffers.back()->data.same_as(new_buffer_var)) { + return buffers.back(); + } + + // Otherwise, make and return a new buffer object that uses the + // new buffer, pushing it onto the scoped stack of existing + // buffers. This will be popped when the new_buffer_var + // redefinition is popped. + Buffer new_buf(new_buffer_var, buf->dtype, buf->shape, buf->strides, buf->elem_offset, + buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type, + buf->axis_separators, buf->span); + buffers.push_back(new_buf); + return new_buf; } Stmt VisitStmt_(const LetStmtNode* op) final { const Var& v = op->var; if (defined_.count(v.get())) { PrimExpr value = this->VisitExpr(op->value); - Var new_var(v->name_hint, v.dtype()); - scope_[v.get()].push_back(new_var); + ScopedRedefine redefine(this, v); Stmt body = this->VisitStmt(op->body); - scope_[v.get()].pop_back(); - return LetStmt(new_var, value, body); + return LetStmt(redefine.new_var, value, body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -178,12 +192,10 @@ class IRConvertSSA final : public StmtExprMutator { Stmt VisitStmt_(const ForNode* op) final { const Var& v = op->loop_var; if (defined_.count(v.get())) { - Var new_var(v->name_hint, v.dtype()); - scope_[v.get()].push_back(new_var); + ScopedRedefine redefine(this, v); Stmt stmt = StmtExprMutator::VisitStmt_(op); - scope_[v.get()].pop_back(); op = stmt.as(); - return For(new_var, op->min, op->extent, op->kind, op->body, op->thread_binding, + return For(redefine.new_var, op->min, op->extent, op->kind, op->body, op->thread_binding, op->annotations); } else { defined_.insert(v.get()); @@ -193,12 +205,10 @@ class IRConvertSSA final : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { const Var& v = op->buffer_var; if (defined_.count(v.get())) { - Var new_var(v->name_hint, v->type_annotation); - scope_[v.get()].push_back(new_var); + ScopedRedefine redefine(this, v); Stmt stmt = StmtExprMutator::VisitStmt_(op); - scope_[v.get()].pop_back(); op = stmt.as(); - return Allocate(new_var, op->dtype, op->extents, op->condition, op->body); + return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -219,9 +229,34 @@ class IRConvertSSA final : public StmtExprMutator { } private: + struct ScopedRedefine { + ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), old_var(old_var) { + if (old_var->type_annotation.defined()) { + new_var = Var(old_var->name_hint, old_var->type_annotation); + } else { + new_var = Var(old_var->name_hint, old_var->dtype); + } + parent->scope_[old_var.get()].push_back(new_var); + } + + ~ScopedRedefine() { + parent->scope_[old_var.get()].pop_back(); + for (auto& kv : parent->buf_remap_) { + std::vector& buffers = kv.second; + if (buffers.size() && (buffers.back()->data.get() == new_var.get())) { + buffers.pop_back(); + } + } + } + + IRConvertSSA* parent; + Var old_var; + Var new_var; + }; + std::unordered_map> scope_; std::unordered_set defined_; - std::unordered_map buf_remap_; + std::unordered_map> buf_remap_; }; Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index 7989fba2d29a..6dba694e45ac 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.script import tir as T import os @@ -110,7 +111,31 @@ def test_unroll_single_count_loops(): assert ret == stmt +def test_unroll_allocations(): + @tvm.script.ir_module + class before: + @T.prim_func + def main(): + for i in T.unroll(2): + with T.allocate([16], "float32", "global") as buf: + buf[0] = 0.0 + + @tvm.script.ir_module + class expected: + @T.prim_func + def main(): + with T.allocate([16], "float32", "global") as buf1: + buf1[0] = 0.0 + with T.allocate([16], "float32", "global") as buf2: + buf2[0] = 0.0 + + after = tvm.tir.transform.UnrollLoop()(before) + + tvm.ir.assert_structural_equal(after, expected) + + if __name__ == "__main__": test_unroll_loop() test_unroll_fake_loop() test_unroll_single_count_loops() + test_unroll_allocations() From 9fa1d07d6bf1c7701cfc7cdb609256bd1537175a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 22 Feb 2022 14:30:40 -0600 Subject: [PATCH 134/177] Ethos-u index/buffer dimension updates. --- .../contrib/test_ethosu/test_replace_copy.py | 34 +++--- .../contrib/test_ethosu/test_scheduler.py | 13 ++- .../test_ethosu/test_tir_to_cs_translator.py | 110 +++++++++--------- .../contrib/test_ethosu/test_vela_api.py | 10 +- 4 files changed, 88 insertions(+), 79 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index f952eca8fbd5..4bfbae5f03b7 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -34,8 +34,10 @@ class ReferenceModule: def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([0], "uint8") - buffer_1 = T.buffer_decl([0], "uint8") + buffer = T.buffer_decl([80], "uint8") + buffer_1 = T.buffer_decl([304], "uint8") + T.preflattened_buffer(placeholder_3, [1, 16, 16, 32], dtype="int8", data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 8], dtype="int8", data=ethosu_write_1.data) # body placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) @@ -78,19 +80,23 @@ class WeightStream: def main(placeholder_5: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(4096,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([0], "uint8") - buffer_1 = T.buffer_decl([0], "uint8") - buffer_2 = T.buffer_decl([0], "uint8") - buffer_3 = T.buffer_decl([0], "uint8") + buffer = T.buffer_decl([416], "uint8") + buffer_1 = T.buffer_decl([112], "uint8") + buffer_2 = T.buffer_decl([272], "uint8") + buffer_3 = T.buffer_decl([64], "uint8") + T.preflattened_buffer(placeholder_5, [1, 16, 16, 32], dtype="int8", data=placeholder_5.data) + T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 16], dtype="int8", data=ethosu_write_1.data) # body - placeholder_global = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", buffer[0], 416, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 112, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 416, 12, placeholder_d_global[0], 112, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 272, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 64, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 272, 12, placeholder_d_global[0], 64, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + placeholder_global_unrolled_iter_0 = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_global_unrolled_iter_1 = T.buffer_decl([272], "uint8", data=placeholder_global_unrolled_iter_0.data) + placeholder_d_global_unrolled_iter_0 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global_unrolled_iter_1 = T.buffer_decl([64], dtype="uint8", data=placeholder_d_global_unrolled_iter_0.data) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 416, placeholder_global_unrolled_iter_0[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 112, placeholder_d_global_unrolled_iter_0[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_0[0], 416, 12, placeholder_d_global_unrolled_iter_0[0], 112, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 272, placeholder_global_unrolled_iter_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 64, placeholder_d_global_unrolled_iter_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_1[0], 272, 12, placeholder_d_global_unrolled_iter_1[0], 64, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 57864218aab6..40f42df9eaad 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -180,12 +180,15 @@ def test_schedule_cache_reads(): @tvm.script.ir_module class DiamondGraphTir: @T.prim_func - def main(input_buffer: T.Buffer[(1, 56, 56, 96), "int8"], output_buffer: T.Buffer[(1, 56, 56, 24), "int8"]) -> None: + def main(input_buffer: T.Buffer[(301056,), "int8"], output_buffer: T.Buffer[(75264,), "int8"]) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - weight_buffer = T.buffer_decl([], "uint8") - bias_buffer = T.buffer_decl([], "uint8") - weight_buffer2 = T.buffer_decl([], "uint8") - bias_buffer2 = T.buffer_decl([], "uint8") + T.preflattened_buffer(input_buffer, [1, 56, 56, 96], dtype='int8', data=input_buffer.data) + T.preflattened_buffer(output_buffer, [1, 56, 56, 24], dtype='int8', data=output_buffer.data) + + weight_buffer = T.buffer_decl([2608], "uint8") + bias_buffer = T.buffer_decl([240], "uint8") + weight_buffer2 = T.buffer_decl([736], "uint8") + bias_buffer2 = T.buffer_decl([240], "uint8") placeholder_global = T.allocate([2608], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True}) diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index fafc4d84ea5b..69501403ee4e 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -33,11 +33,11 @@ @tvm.script.ir_module class SingleEthosUConv2D: @T.prim_func - def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_conv2d_1: T.Buffer[(1024,), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_4 = T.buffer_decl([], "uint8") - placeholder_5 = T.buffer_decl([], "uint8") + placeholder_4 = T.buffer_decl([1], "uint8") + placeholder_5 = T.buffer_decl([1], "uint8") # body T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_4[0], 0, 12, placeholder_5[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) # fmt: on @@ -48,13 +48,13 @@ def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Bu @tvm.script.ir_module class MultiEthosUConv2D: @T.prim_func - def main(placeholder_6: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_conv2d_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: + def main(placeholder_6: T.Buffer[(192,), "int8"], ethosu_conv2d_1: T.Buffer[(512,), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_9 = T.buffer_decl([], "uint8") - placeholder_7 = T.buffer_decl([], "uint8") - placeholder_8 = T.buffer_decl([], "uint8") - placeholder_5 = T.buffer_decl([], "uint8") + placeholder_9 = T.buffer_decl([1], "uint8") + placeholder_7 = T.buffer_decl([1], "uint8") + placeholder_8 = T.buffer_decl([1], "uint8") + placeholder_5 = T.buffer_decl([1], "uint8") # body ethosu_conv2d_2 = T.allocate([1024], "uint8", "global") ethosu_conv2d_3 = T.allocate([2048], "uint8", "global") @@ -70,11 +70,11 @@ def main(placeholder_6: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_conv2d_1: T.Buffe @tvm.script.ir_module class MultiEthosUCopy: @T.prim_func - def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_conv2d_1: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_5 = T.buffer_decl([], "uint8") - placeholder_4 = T.buffer_decl([], "uint8") + placeholder_5 = T.buffer_decl([1], "uint8") + placeholder_4 = T.buffer_decl([1], "uint8") # body placeholder_global = T.allocate([256], "uint8", "global") placeholder_d_global = T.allocate([8], "int32", "global") @@ -89,15 +89,15 @@ def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Bu @tvm.script.ir_module class WeightStreamOnly: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") - buffer_4 = T.buffer_decl([], "uint8") - buffer_5 = T.buffer_decl([], "uint8") - buffer_6 = T.buffer_decl([], "uint8") - buffer_7 = T.buffer_decl([], "uint8") + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + buffer = T.buffer_decl([1], "uint8") + buffer_1 = T.buffer_decl([1], "uint8") + buffer_2 = T.buffer_decl([1], "uint8") + buffer_3 = T.buffer_decl([1], "uint8") + buffer_4 = T.buffer_decl([1], "uint8") + buffer_5 = T.buffer_decl([1], "uint8") + buffer_6 = T.buffer_decl([1], "uint8") + buffer_7 = T.buffer_decl([1], "uint8") # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True, @@ -133,17 +133,17 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ @tvm.script.ir_module class MixedRead: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") - buffer_4 = T.buffer_decl([], "uint8") - buffer_5 = T.buffer_decl([], "uint8") - buffer_6 = T.buffer_decl([], "uint8") - buffer_7 = T.buffer_decl([], "uint8") - buffer_8 = T.buffer_decl([], "uint8") - buffer_9 = T.buffer_decl([], "uint8") + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + buffer = T.buffer_decl([1], "uint8") + buffer_1 = T.buffer_decl([1], "uint8") + buffer_2 = T.buffer_decl([1], "uint8") + buffer_3 = T.buffer_decl([1], "uint8") + buffer_4 = T.buffer_decl([1], "uint8") + buffer_5 = T.buffer_decl([1], "uint8") + buffer_6 = T.buffer_decl([1], "uint8") + buffer_7 = T.buffer_decl([1], "uint8") + buffer_8 = T.buffer_decl([1], "uint8") + buffer_9 = T.buffer_decl([1], "uint8") # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True, @@ -523,10 +523,10 @@ class SingleEthosuDepthwiseConv2D: def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_depthwise_conv2d: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_1, [3, 3, 2, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_2, [3, 10], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_depthwise_conv2d_1 = T.match_buffer(ethosu_depthwise_conv2d, [1, 6, 7, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_1, [18], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_2, [30], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = T.match_buffer(placeholder, [192], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_depthwise_conv2d_1 = T.match_buffer(ethosu_depthwise_conv2d, [126], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder_3[0], 0, 0, 0, T.float32(0.6), 11, "NHWC", 24, 3, 1, "int8", 6, 7, 3, 6, 0, 7, ethosu_depthwise_conv2d_1[0], 0, 0, 0, T.float32(0.26), 15, "NHWC", 21, 3, 1, 2, 3, 1, 1, 1, 1, placeholder_4[0], 18, 13, placeholder_5[0], 30, 0, 0, 0, 0, "CLIP", 15, 105, "TFL", "NONE", dtype="int8")) __tvm_meta__ = None @@ -665,10 +665,10 @@ def populate_ethosu_copy_calls(stmt): @tvm.script.ir_module class MixedConstantDatatypes: @T.prim_func - def main(placeholder_4: T.Buffer[(1, 8, 16, 16), "int8"], ethosu_write_1: T.Buffer[(1, 1, 1, 16), "int8"]) -> None: - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "int16") + def main(placeholder_4: T.Buffer[(2048,), "int8"], ethosu_write_1: T.Buffer[(16,), "int8"]) -> None: + buffer = T.buffer_decl([1], "uint8") + buffer_1 = T.buffer_decl([1], "uint8") + buffer_2 = T.buffer_decl([1], "int16") # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True, @@ -961,8 +961,8 @@ class SingleEthosuPooling: def main(placeholder: T.handle, placeholder_3: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 5, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [75], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_pooling", "int8", 5, 9, 3, 5, 0, 9, placeholder_4[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 5, 3, 5, 0, 5, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 15, 3, 1, "AVG", 2, 3, 2, 1, 1, 1, 1, 1, 1, 0, "CLIP", 10, 100, "TFL", "NONE", dtype="int8")) __tvm_meta__ = None @@ -1038,7 +1038,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1 ) ethosu_write_2 = T.match_buffer( - ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1 + ethosu_write, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1 ) # body T.evaluate(T.call_extern( "ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "ADD", 0, "CLIP", 10, 100, "TFL", dtype="int8")) @@ -1055,7 +1055,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SUB", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None @@ -1070,7 +1070,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MUL", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None @@ -1086,7 +1086,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MIN", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None @@ -1102,7 +1102,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MAX", 0, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None @@ -1118,7 +1118,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHR", 0, "NONE", 0, 0, "TFL", dtype="int32")) __tvm_meta__ = None @@ -1134,7 +1134,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHL", 0, "CLIP", 10, 100, "TFL", dtype="int32")) __tvm_meta__ = None @@ -1255,7 +1255,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "ADD", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None @@ -1270,7 +1270,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SUB", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None @@ -1285,7 +1285,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MUL", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None @@ -1301,7 +1301,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MIN", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None @@ -1317,7 +1317,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MAX", 1, "CLIP", 10, 100, "TFL", dtype="int8")) __tvm_meta__ = None @@ -1333,7 +1333,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHR", 1, "NONE", 0, 0, "TFL", dtype="int32")) __tvm_meta__ = None @@ -1349,7 +1349,7 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1) - ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHL", 1, "CLIP", 10, 100, "TFL", dtype="int32")) __tvm_meta__ = None diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index 5e4aaad304a8..662b35822cc2 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -50,7 +50,7 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_3 = T.match_buffer( - placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + placeholder, [192], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) placeholder_4 = T.match_buffer( placeholder_1, [48], dtype="uint8", elem_offset=0, align=128, offset_factor=1 @@ -59,7 +59,7 @@ def main( placeholder_2, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1 ) ethosu_conv2d_1 = T.match_buffer( - ethosu_conv2d, [1, 8, 8, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + ethosu_conv2d, [1024], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) # body T.evaluate( @@ -142,10 +142,10 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) placeholder_3 = T.match_buffer( - placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + placeholder, [192], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) placeholder_4 = T.match_buffer( - placeholder_1, [16, 1, 1, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + placeholder_1, [48], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) placeholder_5 = T.match_buffer( placeholder_2, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1 @@ -155,7 +155,7 @@ def main( placeholder_6, [16], dtype="float32", elem_offset=0, align=128, offset_factor=1 ) ethosu_conv2d_1 = T.match_buffer( - ethosu_conv2d, [1, 8, 8, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1 + ethosu_conv2d, [1024], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) # body T.evaluate( From 3edb07d28a071956ecd9655009b22ee333d735f3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 22 Feb 2022 15:05:02 -0600 Subject: [PATCH 135/177] Updated ethosu passes to handle buffer load/store. --- .../backend/contrib/ethosu/tir/passes.py | 341 ++++++++++-------- 1 file changed, 191 insertions(+), 150 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 461a6f648722..baa91cdf57c6 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -315,178 +315,219 @@ def EncodeConstants(const_dict): new_const_dict = {} buffer_to_const = {} pointer_to_buffer = {} - encoded_buffers = set() rewrite_buffer = {} rewrite_pointer = {} - accel_config = vela_api.get_accelerator_config() - - def _align_scale_bias(tir_extern_call, bias): - """Align the scale_bias to 16 bytes.""" - value_bytes = bytearray() - value_bytes.extend(bias.tobytes()) - # Align to 16 - remainder = (len(value_bytes)) % 16 - if remainder > 0: - value_bytes.extend(bytearray(16 - remainder)) - value = np.frombuffer(value_bytes, dtype="uint8") - return value - - def _encode_weights(tir_extern_call, weights): - """Encode the weights for a TIR extern call.""" - value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config) - value = np.frombuffer(value_bytes, dtype="uint8") - return value - - def _new_buffer(old_buffer, new_value): - """Create a new buffer and add the old buffer and its pointer to the - rewriting maps.""" - if old_buffer in rewrite_buffer: - new_buffer = rewrite_buffer[old_buffer] - else: - new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype)) - pointer_to_buffer[new_buffer.data] = new_buffer - buffer_to_const[new_buffer] = new_value - - rewrite_buffer[old_buffer] = new_buffer - rewrite_pointer[old_buffer.data] = new_buffer.data - encoded_buffers.add(new_buffer) - - def _visit_encode_pre(stmt): - if isinstance(stmt, tvm.tir.Call): - # Handle copies as a special-case by propagating the buffer information - # from the read to the write pointer. - if stmt.args[0] == "ethosu_copy": - read_pointer = stmt.args[1].buffer.data - if read_pointer in pointer_to_buffer: - write_pointer = stmt.args[3].buffer.data + + def collect_encoding_definitions(stmt, unencoded_buffer_constants): + # Map from copy destination to copy source. + copy_map = {} + # List of buffer copies that occurred + copied_buffers = [] + # List of encoded buffer information + constant_buffer_replacements = [] + + def _align_scale_bias(tir_extern_call, bias): + """Align the scale_bias to 16 bytes.""" + value_bytes = bytearray() + value_bytes.extend(bias.tobytes()) + # Align to 16 + remainder = (len(value_bytes)) % 16 + if remainder > 0: + value_bytes.extend(bytearray(16 - remainder)) + value = np.frombuffer(value_bytes, dtype="uint8") + return value + + accel_config = vela_api.get_accelerator_config() + + def _encode_weights(tir_extern_call, weights): + """Encode the weights for a TIR extern call.""" + value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config) + value = np.frombuffer(value_bytes, dtype="uint8") + return value + + def _declare_constant_buffer(old_buffer, encoded_constants): + """Create a new buffer and add the old buffer and its pointer to the + rewriting maps.""" + new_buffer = tvm.tir.decl_buffer( + shape=[len(encoded_constants)], + dtype=str(encoded_constants.dtype), + name=old_buffer.name + "_encoded", + ) + + constant_buffer_replacements.append( + { + "old_buffer": old_buffer, + "new_buffer": new_buffer, + "encoded_constants": encoded_constants, + } + ) + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + # Handle copies as a special-case by propagating the buffer information + # from the read to the write pointer. + if stmt.args[0] == "ethosu_copy": + read_buffer = stmt.args[1].buffer + write_buffer = stmt.args[3].buffer # Assert writing to the base of the write_var (pre-StorageRewrite) assert list(stmt.args[3].indices) == [0] assert list(stmt.args[1].indices) == [0] - pointer_to_buffer[write_pointer] = pointer_to_buffer[read_pointer] - rewrite_buffer[stmt.args[3].buffer] = stmt.args[1].buffer - else: - # Encode the weights - old_weights_buffer = get_weights_buffer(stmt) - if old_weights_buffer is not None: - assert old_weights_buffer.data in pointer_to_buffer - new_weights_buffer = pointer_to_buffer[old_weights_buffer.data] - weights_value = buffer_to_const[new_weights_buffer] - new_weights_value = _encode_weights(stmt, weights_value) - _new_buffer(new_weights_buffer, new_weights_value) - # Align the scale_bias to 16 bytes - old_scale_bias_buffer = get_scale_bias_buffer(stmt) - if old_scale_bias_buffer is not None: - assert old_scale_bias_buffer.data in pointer_to_buffer - new_scale_bias_buffer = pointer_to_buffer[old_scale_bias_buffer.data] - scale_bias_value = buffer_to_const[new_scale_bias_buffer] - new_scale_bias_value = _align_scale_bias(stmt, scale_bias_value) - _new_buffer(new_scale_bias_buffer, new_scale_bias_value) - - def _visit_encode_post(stmt): - # Because encoding may change the data type (e.g. bias to uint8) and type information - # is stored in pointer vars, it's necessary to rewrite all the pointers which point - # to encoded data. - if isinstance(stmt, tvm.tir.Allocate): - allocate_pointer = stmt.buffer_var - if allocate_pointer in pointer_to_buffer: - buffer = pointer_to_buffer[allocate_pointer] - if buffer in rewrite_buffer: # If the pointer needs rewriting - # Create a new pointer var with the type of the new buffer - new_buffer = rewrite_buffer[buffer] - storage_type = tvm.ir.PrimType(new_buffer.dtype) - new_pointer = tvm.tir.Var( - allocate_pointer.name, - tvm.ir.PointerType(storage_type, buffer.scope()), - allocate_pointer.span, - ) - # Set the new pointer to resolve to the new buffer - pointer_to_buffer[new_pointer] = new_buffer - # Add the old pointer to the pointer rewriting dict - rewrite_pointer[allocate_pointer] = new_pointer - - def _visit_rewrite(stmt): - if isinstance(stmt, tvm.tir.Call): - # For extern calls, we need to rewrite pairs of arguments corresponding to - # base address load and the length of the load. - new_args = [stmt.args[0]] - for i in range(1, len(stmt.args)): - # If the previous argument was a load, the current should be a length - if isinstance(stmt.args[i - 1], tvm.tir.BufferLoad): - load = stmt.args[i - 1] - old_buffer = load.buffer - if old_buffer.data in pointer_to_buffer: - new_buffer = pointer_to_buffer[old_buffer.data] - # Only rewrite the arguments of buffers that have been encoded - if new_buffer in encoded_buffers: - new_arg = np.prod(list(new_buffer.shape)) - new_args.append(new_arg) - continue - new_args.append(stmt.args[i]) - - return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span) - if isinstance(stmt, tvm.tir.Allocate): - # Where a pointer needs rewriting, the allocate for it must be rewritten - allocate_pointer = stmt.buffer_var - if allocate_pointer in pointer_to_buffer: - if pointer_to_buffer[allocate_pointer] in rewrite_buffer: - new_buffer = rewrite_buffer[pointer_to_buffer[allocate_pointer]] - new_pointer = rewrite_pointer[allocate_pointer] + copied_buffers.append({"source": read_buffer, "dest": write_buffer}) + copy_map[write_buffer] = read_buffer + + else: + # Encode the weights + weights_buffer = get_weights_buffer(stmt) + if weights_buffer is not None: + weights_buffer = copy_map[weights_buffer] + unencoded_weights_value = unencoded_buffer_constants[weights_buffer] + encoded_weights_value = _encode_weights(stmt, unencoded_weights_value) + _declare_constant_buffer(weights_buffer, encoded_weights_value) + + # Align the scale_bias to 16 bytes + scale_bias_buffer = get_scale_bias_buffer(stmt) + if scale_bias_buffer is not None: + scale_bias_buffer = copy_map[scale_bias_buffer] + scale_bias_value = unencoded_buffer_constants[scale_bias_buffer] + aligned_scale_bias_value = _align_scale_bias(stmt, scale_bias_value) + _declare_constant_buffer(scale_bias_buffer, aligned_scale_bias_value) + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + + return { + "copied_buffers": copied_buffers, + "constant_buffer_replacements": constant_buffer_replacements, + } + + def transform_stmt(stmt, buf_remap, var_remap, pointer_to_buffer, encoded_buffers): + def _visit_rewrite(stmt): + if isinstance(stmt, tvm.tir.Call): + # For extern calls, we need to rewrite pairs of arguments corresponding to + # base address load and the length of the load. + old_args = list(stmt.args) + + new_args = [stmt.args[0]] + for prev_arg, arg in zip(old_args[:-1], old_args[1:]): + # If the previous argument was a load from an + # encoded buffer, the current should be a length. + if ( + isinstance(prev_arg, tvm.tir.BufferLoad) + and prev_arg.buffer in encoded_buffers + ): + arg = np.prod(list(prev_arg.buffer.shape)) + + new_args.append(arg) + + return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span) + + if isinstance(stmt, tvm.tir.Allocate): + # Where a pointer needs rewriting, the allocate for it must be rewritten + allocate_pointer = stmt.buffer_var + if allocate_pointer in var_remap: + new_allocate_pointer = var_remap[allocate_pointer] + new_buffer = pointer_to_buffer[new_allocate_pointer] + return tvm.tir.Allocate( - new_pointer, + new_buffer.data, new_buffer.dtype, new_buffer.shape, stmt.condition, stmt.body, stmt.span, ) - # The following rewrites would be better expressed by just rewriting the Vars, however - # ir_transform doesn't seem to visit Vars. So instead we do the next best thing and rewrite - # the nodes which contain the Vars. - if isinstance(stmt, tvm.tir.BufferLoad): - if stmt.buffer.data in pointer_to_buffer: - load_buffer = pointer_to_buffer[stmt.buffer.data] - if load_buffer in rewrite_buffer: - new_buffer = rewrite_buffer[load_buffer] - return tvm.tir.BufferLoad(new_buffer, stmt.indices, stmt.span) - if isinstance(stmt, tvm.tir.AttrStmt): - node_pointer = stmt.node - if node_pointer in rewrite_pointer: - return tvm.tir.AttrStmt( - rewrite_pointer[node_pointer], stmt.attr_key, stmt.value, stmt.body, stmt.span - ) - return None - def _ftransform(f, mod, ctx): - for i, param in enumerate(f.params): - if i in const_dict: - buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten() - pointer_to_buffer[f.buffer_map[param].data] = f.buffer_map[param] + # The following rewrites would be better expressed by just + # rewriting the Buffers. However ir_transform doesn't + # visit Buffers, so instead we do the next best thing and + # rewrite the nodes which contain the Buffers. + if isinstance(stmt, tvm.tir.BufferLoad): + if stmt.buffer in buf_remap: + return tvm.tir.BufferLoad(buf_remap[stmt.buffer], stmt.indices, stmt.span) + + if isinstance(stmt, tvm.tir.AttrStmt): + node_pointer = stmt.node + if node_pointer in var_remap: + return tvm.tir.AttrStmt( + var_remap[node_pointer], + stmt.attr_key, + stmt.value, + stmt.body, + stmt.span, + ) - # First analyse what needs to be rewritten - new_body = tvm.tir.stmt_functor.ir_transform( - f.body, _visit_encode_pre, _visit_encode_post, ["tir.Call", "tir.Allocate"] - ) - # Then perform the rewrites - new_body = tvm.tir.stmt_functor.ir_transform( - f.body, + return None + + return tvm.tir.stmt_functor.ir_transform( + stmt, None, _visit_rewrite, ["tir.Call", "tir.Allocate", "tir.BufferLoad", "tir.AttrStmt"], ) + + def _ftransform(f, mod, ctx): + # Step 0: Unpack the constant dictionary in terms of the + # functions buffers. + unencoded_buffer_constants = {} + for i, param in enumerate(f.params): + if i in const_dict: + unencoded_buffer_constants[f.buffer_map[param]] = const_dict[i].flatten() + + # Step 1: Collect information on the buffers that will be + # replaced by encodings. + buffer_information = collect_encoding_definitions(f.body, unencoded_buffer_constants) + + # Step 2: Generate variable/buffer remaps, based on the + # collected information. + buf_remap = {} + encoded_buffers = [] + + # Any encoded buffers must be replaced + for info in buffer_information["constant_buffer_replacements"]: + buf_remap[info["old_buffer"]] = info["new_buffer"] + encoded_buffers.append(info["new_buffer"]) + + # Any buffers that are copied into from an encoded buffer must + # be replaced. + for info in buffer_information["copied_buffers"]: + copy_source = info["source"] + while copy_source in buf_remap: + copy_source = buf_remap[copy_source] + + copy_dest = info["dest"] + + if copy_source.shape != copy_dest.shape or copy_source.dtype != copy_dest.dtype: + new_dest = tvm.tir.decl_buffer( + shape=copy_source.shape, + dtype=copy_source.dtype, + name=copy_dest.name + "_encoded", + ) + buf_remap[copy_dest] = new_dest + encoded_buffers.append(new_dest) + + # Define additional dependent lookup tables. + var_remap = {old.data: new.data for (old, new) in buf_remap.items()} + pointer_to_buffer = { + buf.data: buf for (old, new) in buf_remap.items() for buf in [old, new] + } + buffer_to_const = { + info["new_buffer"]: info["encoded_constants"] + for info in buffer_information["constant_buffer_replacements"] + } + + # Step 3: Then perform the rewrites + new_body = transform_stmt(f.body, buf_remap, var_remap, pointer_to_buffer, encoded_buffers) + + # Step 4: Rewrite the buffer map and const dict to instead use the encoded versions new_buffer_map = {} - # Rewrite the buffer map and const dict to instead use the encoded versions for i, param in enumerate(f.params): buffer = f.buffer_map[param] - if buffer in rewrite_buffer: - new_buffer = rewrite_buffer[buffer] - new_buffer_map[param] = new_buffer - new_value = buffer_to_const[new_buffer] - new_const_dict[i] = new_value - elif buffer in buffer_to_const: + if buffer in buf_remap: + buffer = buf_remap[buffer] + + if buffer in buffer_to_const: new_const_dict[i] = buffer_to_const[buffer] - new_buffer_map[param] = buffer - else: - new_buffer_map[param] = buffer + + new_buffer_map[param] = buffer new_f = tvm.tir.PrimFunc( f.params, From ea0b4f942239e95fa4267979298c6cbc10a1705e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 23 Feb 2022 10:19:47 -0600 Subject: [PATCH 136/177] Resolved bug in tvmscript printing of duplicate buffers. --- src/printer/tvmscript_printer.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 7931b1cddea5..4e5187a91da2 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1014,10 +1014,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { // value of T.allocate, and no T.buffer_decl statement is needed. Buffer alloc_buf(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0, 0, kDefault); + bool found_alloc_buf = false; Array aliasing_buffers; for (const auto& buf : buffer_usage) { - if (is_exact_match(buf, alloc_buf)) { + if (!found_alloc_buf && is_exact_match(buf, alloc_buf)) { alloc_buf = buf; + found_alloc_buf = true; } else { aliasing_buffers.push_back(buf); } From 3f52fa391a58d201d85371392d6f3f2bfbcca842 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 23 Feb 2022 08:46:07 -0600 Subject: [PATCH 137/177] Fix breakage in ethos-u test_assign_addresses, encode constants --- .../backend/contrib/ethosu/tir/passes.py | 61 +++++---- .../contrib/ethosu/tir_to_cs_translator.py | 2 + tests/python/contrib/test_ethosu/infra.py | 4 +- .../test_ethosu/test_encode_constants.py | 94 ++++++++------ .../test_ethosu/test_remove_concatenates.py | 21 +-- .../test_ethosu/test_replace_conv2d.py | 120 +++++++++++------- .../test_ethosu/test_tir_to_cs_translator.py | 18 +-- 7 files changed, 183 insertions(+), 137 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index baa91cdf57c6..3fb1334801a4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -245,7 +245,9 @@ def _visit(stmt): # If it's anything other than a full read, create a new buffer if offset != 0 or len(const) != length: new_consts.append(const[offset : offset + length]) - new_buffer = tvm.tir.decl_buffer((length,), arg.dtype) + new_buffer = tvm.tir.decl_buffer( + (length,), arg.dtype, scope=arg.buffer.scope() + ) new_buffers.append(new_buffer) new_args.append(tvm.tir.expr.BufferLoad(new_buffer, [0])) continue @@ -313,12 +315,11 @@ def EncodeConstants(const_dict): """ new_const_dict = {} - buffer_to_const = {} pointer_to_buffer = {} rewrite_buffer = {} rewrite_pointer = {} - def collect_encoding_definitions(stmt, unencoded_buffer_constants): + def collect_encoding_definitions(stmt, old_buffer_to_const): # Map from copy destination to copy source. copy_map = {} # List of buffer copies that occurred @@ -352,6 +353,7 @@ def _declare_constant_buffer(old_buffer, encoded_constants): shape=[len(encoded_constants)], dtype=str(encoded_constants.dtype), name=old_buffer.name + "_encoded", + scope=old_buffer.scope(), ) constant_buffer_replacements.append( @@ -379,16 +381,18 @@ def _visit(stmt): # Encode the weights weights_buffer = get_weights_buffer(stmt) if weights_buffer is not None: - weights_buffer = copy_map[weights_buffer] - unencoded_weights_value = unencoded_buffer_constants[weights_buffer] + if weights_buffer in copy_map: + weights_buffer = copy_map[weights_buffer] + unencoded_weights_value = old_buffer_to_const[weights_buffer] encoded_weights_value = _encode_weights(stmt, unencoded_weights_value) _declare_constant_buffer(weights_buffer, encoded_weights_value) # Align the scale_bias to 16 bytes scale_bias_buffer = get_scale_bias_buffer(stmt) if scale_bias_buffer is not None: - scale_bias_buffer = copy_map[scale_bias_buffer] - scale_bias_value = unencoded_buffer_constants[scale_bias_buffer] + if scale_bias_buffer in copy_map: + scale_bias_buffer = copy_map[scale_bias_buffer] + scale_bias_value = old_buffer_to_const[scale_bias_buffer] aligned_scale_bias_value = _align_scale_bias(stmt, scale_bias_value) _declare_constant_buffer(scale_bias_buffer, aligned_scale_bias_value) @@ -399,7 +403,7 @@ def _visit(stmt): "constant_buffer_replacements": constant_buffer_replacements, } - def transform_stmt(stmt, buf_remap, var_remap, pointer_to_buffer, encoded_buffers): + def transform_stmt(stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const): def _visit_rewrite(stmt): if isinstance(stmt, tvm.tir.Call): # For extern calls, we need to rewrite pairs of arguments corresponding to @@ -412,7 +416,7 @@ def _visit_rewrite(stmt): # encoded buffer, the current should be a length. if ( isinstance(prev_arg, tvm.tir.BufferLoad) - and prev_arg.buffer in encoded_buffers + and prev_arg.buffer in new_buffer_to_const ): arg = np.prod(list(prev_arg.buffer.shape)) @@ -467,24 +471,24 @@ def _visit_rewrite(stmt): def _ftransform(f, mod, ctx): # Step 0: Unpack the constant dictionary in terms of the # functions buffers. - unencoded_buffer_constants = {} + old_buffer_to_const = {} for i, param in enumerate(f.params): if i in const_dict: - unencoded_buffer_constants[f.buffer_map[param]] = const_dict[i].flatten() + old_buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten() # Step 1: Collect information on the buffers that will be # replaced by encodings. - buffer_information = collect_encoding_definitions(f.body, unencoded_buffer_constants) + buffer_information = collect_encoding_definitions(f.body, old_buffer_to_const) # Step 2: Generate variable/buffer remaps, based on the # collected information. buf_remap = {} - encoded_buffers = [] + new_buffer_to_const = {} # Any encoded buffers must be replaced for info in buffer_information["constant_buffer_replacements"]: buf_remap[info["old_buffer"]] = info["new_buffer"] - encoded_buffers.append(info["new_buffer"]) + new_buffer_to_const[info["new_buffer"]] = info["encoded_constants"] # Any buffers that are copied into from an encoded buffer must # be replaced. @@ -499,23 +503,23 @@ def _ftransform(f, mod, ctx): new_dest = tvm.tir.decl_buffer( shape=copy_source.shape, dtype=copy_source.dtype, - name=copy_dest.name + "_encoded", + name=copy_dest.name, + scope=copy_dest.scope(), ) buf_remap[copy_dest] = new_dest - encoded_buffers.append(new_dest) + if copy_source in new_buffer_to_const: + new_buffer_to_const[new_dest] = new_buffer_to_const[copy_source] # Define additional dependent lookup tables. var_remap = {old.data: new.data for (old, new) in buf_remap.items()} pointer_to_buffer = { buf.data: buf for (old, new) in buf_remap.items() for buf in [old, new] } - buffer_to_const = { - info["new_buffer"]: info["encoded_constants"] - for info in buffer_information["constant_buffer_replacements"] - } # Step 3: Then perform the rewrites - new_body = transform_stmt(f.body, buf_remap, var_remap, pointer_to_buffer, encoded_buffers) + new_body = transform_stmt( + f.body, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const + ) # Step 4: Rewrite the buffer map and const dict to instead use the encoded versions new_buffer_map = {} @@ -524,8 +528,10 @@ def _ftransform(f, mod, ctx): if buffer in buf_remap: buffer = buf_remap[buffer] - if buffer in buffer_to_const: - new_const_dict[i] = buffer_to_const[buffer] + if buffer in new_buffer_to_const: + new_const_dict[i] = new_buffer_to_const[buffer] + elif buffer in old_buffer_to_const: + new_const_dict[i] = old_buffer_to_const[buffer] new_buffer_map[param] = buffer @@ -774,11 +780,12 @@ def _ftransform(f, mod, ctx): # We are using buffer_var to key the constants as # PrimFunc params of constants will be removed. new_const_dict[f.buffer_map[f.params[param_idx]].data] = const_dict[param_idx] - for i in range(len(f.params)): + for i, param in enumerate(f.params): if i not in const_dict.keys(): - new_params.append(f.params[i]) - new_buffer_map[f.params[i]] = f.buffer_map[f.params[i]] - new_preflattened_buffer_map[f.params[i]] = f.preflattened_buffer_map[f.params[i]] + new_params.append(param) + new_buffer_map[param] = f.buffer_map[param] + if param in f.preflattened_buffer_map: + new_preflattened_buffer_map[param] = f.preflattened_buffer_map[param] return tvm.tir.PrimFunc( new_params, f.body, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index ce5cfd49883d..9b1bbf4d03c0 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -334,6 +334,8 @@ def extract_buffer_info( primfunc = mod.functions.items()[0][1] for param, const_data in param_dict.items(): + if isinstance(param, tvm.tir.Buffer): + param = param.data buffer_info[param] = BufferInfo( const_data, const_data.shape, const_data.dtype, BufferType.constant ) diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index d79fe679d482..934ce6929bc7 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -366,8 +366,8 @@ def get_convolutional_args(call, include_buffers=False, remove_constants=False): continue elif isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): conv_args.append(arg.value) - elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: - conv_args.append(arg.index) + elif isinstance(arg, tvm.tir.expr.BufferLoad) and not include_buffers: + conv_args.append(arg.indices[0]) else: conv_args.append(arg) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 7bf0c9a181aa..8878e467aad7 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -34,32 +34,36 @@ @tvm.script.ir_module class WeightStreamOnly: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") - buffer_4 = T.buffer_decl([], "uint8") - buffer_5 = T.buffer_decl([], "uint8") - buffer_6 = T.buffer_decl([], "uint8") - buffer_7 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([128], "uint8") + buffer_1 = T.buffer_decl([32], "uint8") + buffer_2 = T.buffer_decl([112], "uint8") + buffer_3 = T.buffer_decl([32], "uint8") + buffer_4 = T.buffer_decl([112], "uint8") + buffer_5 = T.buffer_decl([32], "uint8") + buffer_6 = T.buffer_decl([112], "uint8") + buffer_7 = T.buffer_decl([32], "uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body - placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", buffer[0], 128, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 128, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 112, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 112, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 112, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + p1_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1_global_1 = T.buffer_decl([112], dtype="uint8", data=p1_global.data) + p2_global_1 = T.buffer_decl([32], dtype="uint8", data=p2_global.data) + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 128, p1_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 32, p2_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global[0], 128, 12, p2_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 112, p1_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, p2_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, 12, p2_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 112, p1_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, p2_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, 12, p2_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 112, p1_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, p2_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, 12, p2_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -107,11 +111,13 @@ def _get_func(): @tvm.script.ir_module class RereadWeights: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([304], "uint8") + buffer_1 = T.buffer_decl([80], "uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) @@ -168,13 +174,15 @@ def _get_func(): @tvm.script.ir_module class DirectReadOnly: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([592], "uint8") + buffer_1 = T.buffer_decl([160], "uint8") + buffer_2 = T.buffer_decl([160], "uint8") + buffer_3 = T.buffer_decl([80], "uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, 12, buffer_1[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -225,19 +233,21 @@ def _get_func(): @tvm.script.ir_module class MixedRead: @T.prim_func - def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") - buffer_4 = T.buffer_decl([], "uint8") - buffer_5 = T.buffer_decl([], "uint8") - buffer_6 = T.buffer_decl([], "uint8") - buffer_7 = T.buffer_decl([], "uint8") - buffer_8 = T.buffer_decl([], "uint8") - buffer_9 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([592], "uint8") + buffer_1 = T.buffer_decl([160], "uint8") + buffer_2 = T.buffer_decl([80], "uint8") + buffer_3 = T.buffer_decl([32], "uint8") + buffer_4 = T.buffer_decl([80], "uint8") + buffer_5 = T.buffer_decl([32], "uint8") + buffer_6 = T.buffer_decl([80], "uint8") + buffer_7 = T.buffer_decl([32], "uint8") + buffer_8 = T.buffer_decl([80], "uint8") + buffer_9 = T.buffer_decl([32], "uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index d9b08d521be5..f82351c28c05 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -30,17 +30,20 @@ @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.Buffer[(1, 8, 12, 16), "int8"], placeholder_1: T.Buffer[(1, 8, 10, 16), "int8"], T_concat: T.Buffer[(1, 8, 32, 16), "int8"]) -> None: + def main(placeholder: T.Buffer[(1536,), "int8"], placeholder_1: T.Buffer[(1280,), "int8"], T_concat: T.Buffer[(4096,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") - buffer_4 = T.buffer_decl([], "uint8") - buffer_5 = T.buffer_decl([], "uint8") - buffer_6 = T.buffer_decl([], "uint8") - buffer_7 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([2992], "uint8") + buffer_1 = T.buffer_decl([160], "uint8") + buffer_2 = T.buffer_decl([2992], "uint8") + buffer_3 = T.buffer_decl([160], "uint8") + buffer_4 = T.buffer_decl([2992], "uint8") + buffer_5 = T.buffer_decl([160], "uint8") + buffer_6 = T.buffer_decl([2992], "uint8") + buffer_7 = T.buffer_decl([160], "uint8") + T.preflattened_buffer(placeholder, [1, 8, 12, 16], "int8", data=placeholder.data) + T.preflattened_buffer(placeholder_1, [1, 8, 10, 16], "int8", data=placeholder_1.data) + T.preflattened_buffer(T_concat, [1, 8, 32, 16], "int8", data=T_concat.data) # body T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, 12, buffer_1[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index d9a31a9b86f3..5a9aa9855183 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -333,13 +333,15 @@ def _visit(stmt): @tvm.script.ir_module class Conv2dDoubleCascade1: @T.prim_func - def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([304], "uint8") + buffer_1 = T.buffer_decl([80], "uint8") + buffer_2 = T.buffer_decl([320], "uint8") + buffer_3 = T.buffer_decl([160], "uint8") + T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, 12, buffer_2[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -352,13 +354,15 @@ def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer @tvm.script.ir_module class Conv2dDoubleCascade2: @T.prim_func - def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([80], "uint8") + buffer_1 = T.buffer_decl([320], "uint8") + buffer_2 = T.buffer_decl([1312], "uint8") + buffer_3 = T.buffer_decl([2608], "uint8") + T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, 12, buffer_1[0], 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -371,13 +375,15 @@ def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer @tvm.script.ir_module class Conv2dDoubleCascade3: @T.prim_func - def main(placeholder_5: T.Buffer[(1, 16, 16, 3), "int8"], ethosu_write_1: T.Buffer[(1, 20, 4, 8), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(768,), "int8"], ethosu_write_1: T.Buffer[(640,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([1744], "uint8") + buffer_1 = T.buffer_decl([80], "uint8") + buffer_2 = T.buffer_decl([320], "uint8") + buffer_3 = T.buffer_decl([880], "uint8") + T.preflattened_buffer(placeholder_5, [1, 16, 16, 3], 'int8', data=placeholder_5.data) + T.preflattened_buffer(ethosu_write_1, [1, 20, 4, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, 12, buffer_2[0], 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -392,13 +398,15 @@ def main(placeholder_5: T.Buffer[(1, 16, 16, 3), "int8"], ethosu_write_1: T.Buff @tvm.script.ir_module class Conv2dDoubleCascade4: @T.prim_func - def main(placeholder_5: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write_1: T.Buffer[(1, 8, 2, 8, 16), "int8"]) -> None: + def main(placeholder_5: T.Buffer[(1024,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([1456], "uint8") + buffer_1 = T.buffer_decl([352], "uint8") + buffer_2 = T.buffer_decl([272], "uint8") + buffer_3 = T.buffer_decl([11040], "uint8") + T.preflattened_buffer(placeholder_5, [1, 8, 1, 8, 16], 'int8', data=placeholder_5.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 2, 8, 16], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, 12, buffer_1[0], 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -411,13 +419,15 @@ def main(placeholder_5: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write_1: T.Bu @tvm.script.ir_module class Conv2dDoubleCascade5: @T.prim_func - def main(placeholder: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write: T.Buffer[(1, 32, 32, 8), "int8"]) -> None: + def main(placeholder: T.Buffer[(192,), "int8"], ethosu_write: T.Buffer[(8192,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([160], "uint8") + buffer_1 = T.buffer_decl([320], "uint8") + buffer_2 = T.buffer_decl([304], "uint8") + buffer_3 = T.buffer_decl([80], "uint8") + T.preflattened_buffer(placeholder, [1, 8, 8, 3], 'int8', data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 32, 32, 8], 'int8', data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, 12, buffer_1[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) @@ -430,13 +440,15 @@ def main(placeholder: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write: T.Buffer[(1, @tvm.script.ir_module class Conv2dDoubleCascade6: @T.prim_func - def main(placeholder: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write: T.Buffer[(1, 32, 2, 32, 16), "int8"]) -> None: + def main(placeholder: T.Buffer[(1024,), "int8"], ethosu_write: T.Buffer[(32768,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") - buffer_2 = T.buffer_decl([], "uint8") - buffer_3 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([1456], "uint8") + buffer_1 = T.buffer_decl([352], "uint8") + buffer_2 = T.buffer_decl([11040], "uint8") + buffer_3 = T.buffer_decl([272], "uint8") + T.preflattened_buffer(placeholder, [1, 8, 1, 8, 16], 'int8', data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 32, 2, 32, 16], 'int8', data=ethosu_write.data) # body ethosu_write_1 = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, 12, buffer_1[0], 352, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) @@ -591,11 +603,13 @@ def _get_func( @tvm.script.ir_module class Conv2dInlineCopy1: @T.prim_func - def main(placeholder_3: T.Buffer[(1, 10, 12, 8), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(960,), "int8"], ethosu_write_1: T.Buffer[(1024,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([848], "uint8") + buffer_1 = T.buffer_decl([160], "uint8") + T.preflattened_buffer(placeholder_3, [1, 10, 12, 8], 'int8', data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, 12, buffer_1[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -604,11 +618,13 @@ def main(placeholder_3: T.Buffer[(1, 10, 12, 8), "int8"], ethosu_write_1: T.Buff @tvm.script.ir_module class Conv2dInlineCopy2: @T.prim_func - def main(placeholder_3: T.Buffer[(1, 7, 9, 5), "int8"], ethosu_write_1: T.Buffer[(1, 3, 5, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(315,), "int8"], ethosu_write_1: T.Buffer[(240,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([160], "uint8") + buffer_1 = T.buffer_decl([656], "uint8") + T.preflattened_buffer(placeholder_3, [1, 7, 9, 5], 'int8', data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 3, 5, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, 12, buffer[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -646,11 +662,13 @@ def _get_func(ifm_shape, lower, upper, ofm_channels=16): @tvm.script.ir_module class Conv2dInlineReshape1: @T.prim_func - def main(placeholder_3: T.Buffer[(4, 6, 8, 1), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([160], "uint8") + buffer_1 = T.buffer_decl([848], "uint8") + T.preflattened_buffer(placeholder_3, [4, 6, 8, 1], 'int8', data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -660,11 +678,13 @@ def main(placeholder_3: T.Buffer[(4, 6, 8, 1), "int8"], ethosu_write_1: T.Buffer @tvm.script.ir_module class Conv2dInlineReshape2: @T.prim_func - def main(placeholder_3: T.Buffer[(1, 24, 8), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([160], "uint8") + buffer_1 = T.buffer_decl([848], "uint8") + T.preflattened_buffer(placeholder_3, [1, 24, 8], 'int8', data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -674,11 +694,13 @@ def main(placeholder_3: T.Buffer[(1, 24, 8), "int8"], ethosu_write_1: T.Buffer[( @tvm.script.ir_module class Conv2dInlineReshape3: @T.prim_func - def main(placeholder_3: T.Buffer[(192, 1), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([160], "uint8") + buffer_1 = T.buffer_decl([848], "uint8") + T.preflattened_buffer(placeholder_3, [192, 1], 'int8', data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -688,11 +710,13 @@ def main(placeholder_3: T.Buffer[(192, 1), "int8"], ethosu_write_1: T.Buffer[(1, @tvm.script.ir_module class Conv2dInlineReshape4: @T.prim_func - def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: + def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.buffer_decl([], "uint8") - buffer_1 = T.buffer_decl([], "uint8") + buffer = T.buffer_decl([160], "uint8") + buffer_1 = T.buffer_decl([848], "uint8") + T.preflattened_buffer(placeholder_3, [192], 'int8', data=placeholder_3.data) + T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index c46114ebba79..8169f7b86d5b 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -73,7 +73,7 @@ class MultiEthosUCopy: def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_conv2d_1: T.Buffer[(2048,), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_5 = T.buffer_decl([1], "uint8") + placeholder_5 = T.buffer_decl([1], "int32") placeholder_4 = T.buffer_decl([1], "uint8") # body placeholder_global = T.allocate([256], "uint8", "global") @@ -655,8 +655,8 @@ def populate_ethosu_copy_calls(stmt): ethosu_copy_calls = extract_ethosu_copy_extern_calls(test_case["tir_module"]) for idx, ethosu_copy_call in enumerate(ethosu_copy_calls): npu_dma_op = tir_to_cs_translator.translate_ethosu_tir_call_extern(ethosu_copy_call) - assert npu_dma_op.src.address.buffer_var.name == test_case["ref"][idx]["src"] - assert npu_dma_op.dest.address.buffer_var.name == test_case["ref"][idx]["dest"] + assert npu_dma_op.src.address.buffer.name == test_case["ref"][idx]["src"] + assert npu_dma_op.dest.address.buffer.name == test_case["ref"][idx]["dest"] assert npu_dma_op.src.length == test_case["ref"][idx]["length"] assert npu_dma_op.dest.length == test_case["ref"][idx]["length"] @@ -901,11 +901,11 @@ def check_buffer(address, region, length, buffer_var): for npu_op in npu_ops: if isinstance(npu_op, vapi.NpuDmaOperation): - src_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer_var + src_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer.data check_buffer( npu_op.src.address, npu_op.src.region, npu_op.src.length, src_tir_buffer_var ) - dest_tir_load = npu_op_tir_buffers[npu_op][1].buffer_var + dest_tir_load = npu_op_tir_buffers[npu_op][1].buffer.data check_buffer( npu_op.dest.address, npu_op.dest.region, @@ -913,7 +913,7 @@ def check_buffer(address, region, length, buffer_var): dest_tir_load, ) elif issubclass(type(npu_op), vapi.NpuBlockOperation): - ifm_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer_var + ifm_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer.data ifm_length = ( npu_op.ifm.shape.height * npu_op.ifm.shape.width * npu_op.ifm.shape.depth ) @@ -923,7 +923,7 @@ def check_buffer(address, region, length, buffer_var): ifm_length, ifm_tir_buffer_var, ) - ofm_tir_buffer_var = npu_op_tir_buffers[npu_op][1].buffer_var + ofm_tir_buffer_var = npu_op_tir_buffers[npu_op][1].buffer.data ofm_length = ( npu_op.ofm.shape.height * npu_op.ofm.shape.width * npu_op.ofm.shape.depth ) @@ -939,7 +939,7 @@ def check_buffer(address, region, length, buffer_var): npu_op.weights[idx].address, npu_op.weights[idx].region, npu_op.weights[idx].length, - weight.address.buffer_var, + weight.address.buffer.data, ) for idx, bias in enumerate(npu_op_tir_buffers[npu_op][3]): assert isinstance(bias, vapi.NpuAddressRange) @@ -947,7 +947,7 @@ def check_buffer(address, region, length, buffer_var): npu_op.biases[idx].address, npu_op.biases[idx].region, npu_op.biases[idx].length, - bias.address.buffer_var, + bias.address.buffer.data, ) for test_case in test_cases: From b476517b10fa7776aba687f5c6d36a50152e3da9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 24 Feb 2022 10:50:36 -0600 Subject: [PATCH 138/177] Apply same changes to T.allocate_const as to T.allocate Return a buffer when used in TVMScript, allow for aliasing buffers. --- python/tvm/script/tir/scope_handler.py | 18 +-- src/printer/tvmscript_printer.cc | 66 +++++++---- .../unittest/test_tvmscript_roundtrip.py | 105 ++++++------------ 3 files changed, 90 insertions(+), 99 deletions(-) diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index a86a6942dd49..47643b4ba5d1 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -182,11 +182,11 @@ def allocate_const(raw_data, dtype, shape, span=None): for i in raw_data: list_data.append(i.value) nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype)) - n = tvm.tir.AllocateConst(self.buffer_var, dtype, shape, nd_data, self.body, span=span) + n = tvm.tir.AllocateConst(self.buffer.data, dtype, shape, nd_data, self.body, span=span) return n super().__init__(allocate_const, concise_scope=True, def_symbol=True) - self.buffer_var = None + self.buffer = None def enter_scope( self, @@ -210,13 +210,17 @@ def enter_scope( else: raise Exception("Internal Bug") - def setup_buffer_var(data, dtype, shape, span: Span = None): + def setup_buffer(data, dtype, shape, span: Span = None): """Setup buffer var for a given type.""" - buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) - self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) + self.buffer = tvm.tir.decl_buffer( + shape=shape, + dtype=dtype, + name=name, + span=span, + ) - setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) - context.update_symbol(name, self.buffer_var, node) + setup_buffer(*arg_list, span=tvm_span_from_synr(var_span)) + context.update_symbol(name, self.buffer, node) @register diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index bdc2165b8742..8f4b9caa3e5d 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1011,7 +1011,20 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { return Doc(); } -Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { +namespace { +struct AllocUsage { + Buffer alloc_buffer; + Array aliasing_buffers; +}; + +template +AllocUsage find_allocate_usage(AllocNode* op, Map>* cache_ptr) { + Map>& cache = *cache_ptr; + if (!cache.count(op->buffer_var)) { + cache = BufferUsageFinder::FindUsage(std::move(cache), op->body); + } + Array buffer_usage = cache.Get(op->buffer_var).value_or({}); + auto is_exact_match = [](Buffer a, Buffer b) { if (a->dtype != b->dtype) return false; if (a->shape.size() != b->shape.size()) return false; @@ -1025,28 +1038,32 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { return true; }; - if (!buffer_var_usage_.count(op->buffer_var)) { - buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body); - } - Array buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({}); - // If the buffer allocated via T.allocate is an exact match to the // usage of the buffer later on, then that buffer is the return // value of T.allocate, and no T.buffer_decl statement is needed. - Buffer alloc_buf(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0, 0, - kDefault); + Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0, + 0, kDefault); bool found_alloc_buf = false; Array aliasing_buffers; for (const auto& buf : buffer_usage) { - if (!found_alloc_buf && is_exact_match(buf, alloc_buf)) { - alloc_buf = buf; + if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) { + alloc_buffer = buf; found_alloc_buf = true; } else { aliasing_buffers.push_back(buf); } } - buf_not_in_headers_.insert(alloc_buf.get()); - var_not_in_headers_.insert(alloc_buf->data.get()); + + return AllocUsage{alloc_buffer, aliasing_buffers}; +} +} // namespace + +Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { + auto usage = find_allocate_usage(op, &buffer_var_usage_); + Buffer& alloc_buffer = usage.alloc_buffer; + Array& aliasing_buffers = usage.aliasing_buffers; + buf_not_in_headers_.insert(alloc_buffer.get()); + var_not_in_headers_.insert(alloc_buffer->data.get()); auto storage_scope = GetPtrStorageScope(op->buffer_var); Doc func_call; @@ -1064,11 +1081,11 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; if (current_num_ != num_child_ - 1) { - doc << "with " << func_call << " as " << Print(alloc_buf) << ":"; + doc << "with " << func_call << " as " << Print(alloc_buffer) << ":"; doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(op->body)); } else { - doc << Print(alloc_buf) << " = " << func_call << Doc::NewLine(); + doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine(); doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(op->body); } TryDeallocVar(op->buffer_var); @@ -1105,16 +1122,25 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { } auto ndarray_str = ss.str(); + auto usage = find_allocate_usage(alloc, &buffer_var_usage_); + Buffer& alloc_buffer = usage.alloc_buffer; + Array& aliasing_buffers = usage.aliasing_buffers; + buf_not_in_headers_.insert(alloc_buffer.get()); + var_not_in_headers_.insert(alloc_buffer->data.get()); + + Doc func_call; + func_call << tir_prefix_ << ".allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) + << ", " << Print(alloc->extents) << ")"; + Doc doc; var_not_in_headers_.insert(alloc->buffer_var.get()); if (current_num_ != num_child_ - 1) { - doc << "with tir.allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) << ", " - << Print(alloc->extents) << ")"; - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body)); + doc << "with " << func_call << " as " << Print(alloc_buffer) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers) + << PrintBody(alloc->body)); } else { - doc << Print(alloc->buffer_var) << " = tir.allocate_const(" << ndarray_str << ", " - << PrintDType(alloc->dtype) << ", " << Print(alloc->extents); - doc << ")" << Doc::NewLine() << PrintBody(alloc->body); + doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine(); + doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(alloc->body); } return doc; } diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 540943802763..05f3be7cf3a3 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -36,8 +36,8 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: # buffer definition C_global = T.buffer_decl([1024, 1024], elem_offset=0, align=128, offset_factor=1) packedB = T.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) - A_1 = T.match_buffer(A, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) - B_1 = T.match_buffer(B, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) + A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) # body T.realize(packedB[0:32, 0:1024, 0:32], "") @@ -90,16 +90,14 @@ class Module: def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) - A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + A_1 = T.match_buffer(A, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) C_1 = T.match_buffer(C, [1024 * 1024], elem_offset=0, align=128, offset_factor=1) # body packedB = T.allocate([32768], "float32", "global") for x in T.parallel(0, 32): for y in T.serial(0, 1024): - packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B_1[ - T.ramp(((y * 1024) + (x * 32)), 1, 32) - ] + packedB[T.ramp(((x * 32768) + (y * 32)), 1, 32)] = B_1[y, T.ramp(x * 32, 1, 32)] for x_outer in T.parallel(0, 32): C_global = T.allocate([1024], "float32", "global") for y_outer in T.serial(0, 32): @@ -208,7 +206,7 @@ def mmult( A_data: T.Ptr[T.int32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle") T.attr(A_data, "storage_alignment", 128) - A: T.Buffer = T.buffer_decl([1024, 1024], dtype="int32", data=A_data) + A: T.Buffer = T.buffer_decl([1024 * 1024], dtype="int32", data=A_data) buf0_shape_data: T.Ptr[T.int32] = T.tvm_struct_get(arg0, 0, 2, dtype="handle") buf0_shape: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf0_shape_data) buf0_strides_data: T.Ptr[T.int32] = T.tvm_struct_get(arg0, 0, 3, dtype="handle") @@ -218,7 +216,7 @@ def mmult( B_data: T.Ptr[T.int32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle") T.attr(B_data, "storage_alignment", 128) - B: T.Buffer = T.buffer_decl([1024, 1024], dtype="int32", data=B_data) + B: T.Buffer = T.buffer_decl([1024 * 1024], dtype="int32", data=B_data) buf1_shape_data: T.Ptr[T.int32] = T.tvm_struct_get(arg1, 0, 2, dtype="handle") buf1_shape: T.Buffer = T.buffer_decl([2], dtype="int32", data=buf1_shape_data) buf1_strides_data: T.Ptr[T.int32] = T.tvm_struct_get(arg1, 0, 3, dtype="handle") @@ -937,19 +935,17 @@ def func(A: T.handle, W: T.handle, Conv: T.handle) -> None: def opt_conv_tensorcore_lower(): @T.prim_func - def func(A: T.handle, W: T.handle, Conv: T.handle) -> None: + def func( + A: T.Buffer[(16, 14, 14, 16, 16, 16), "float16"], + W: T.Buffer[(3, 3, 16, 32, 16, 16), "float16"], + Conv: T.Buffer[(16, 14, 14, 32, 16, 16), "float32"], + ) -> None: # function attr dict T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) # body - A_1 = T.match_buffer( - A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - W_1 = T.match_buffer( - W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 - ) - Conv_1 = T.match_buffer( - Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 - ) + A_1 = T.buffer_decl([12845056], dtype="float16", data=A.data) + W_1 = T.buffer_decl([1179648], dtype="float16", data=W.data) + Conv_1 = T.buffer_decl([25690112], data=Conv.data) bx = T.env_thread("blockIdx.x") by = T.env_thread("blockIdx.y") bz = T.env_thread("blockIdx.z") @@ -2473,7 +2469,7 @@ def opt_conv_tensorcore_mod_host( def vthread_func(): @T.prim_func def vthread_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") + A = T.match_buffer(a, [256], "float32") C = T.match_buffer(c, [256], "float32") i0 = T.env_thread("blockIdx.x") @@ -2704,7 +2700,7 @@ def block_elements(a: T.handle, b: T.handle) -> None: D = T.match_buffer(A[0:4, 0], (4, 1)) with T.init(): B[0, 0] = T.float32(0) - B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2] + B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2, 0] return block_elements @@ -2762,16 +2758,7 @@ def test_opaque_block(): assert len(root_block.body.body[1].block.iter_vars) == 0 -def rank0(): - @T.prim_func - def rank0(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - B = T.alloc_buffer((), "float32") - A[()] = 2 - B[()] = A[()] - - -def Module4(): +def module_const(): @tvm.script.ir_module class Module4: # There is an ongoing (python)dict->(c++)Map->(python)dict issue which potentially @@ -2807,11 +2794,11 @@ def B(a: T.handle, c: T.handle) -> None: K1 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K1, x) + B[x] = A[x] + K1[x] K2 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) for x in T.serial(0, 10): - B[x] = B[x] + T.load("int32", K2, x) + B[x] = B[x] + K2[x] for x in T.serial(0, 10): C[x] = B[x] @@ -2819,12 +2806,6 @@ def B(a: T.handle, c: T.handle) -> None: return Module4 -def test_module_const(): - func = Module4() - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) - - def constant(): @T.prim_func def constant(a: T.handle, c: T.handle) -> None: @@ -2833,7 +2814,7 @@ def constant(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((10), "int32") K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K, x) + B[x] = A[x] + K[x] for x in T.serial(0, 10): C[x] = B[x] @@ -2841,37 +2822,15 @@ def constant(a: T.handle, c: T.handle) -> None: return constant -def test_const(): - func = constant() - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) - - -@T.prim_func -def rank0(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - B = T.alloc_buffer((), "float32") - A[()] = 2 - B[()] = A[()] - - -def test_rank0_buffers(): - func = rank0 - rt_func = tvm.script.from_source(func.script(show_meta=True)) - tvm.ir.assert_structural_equal(func, rt_func) - - -@T.prim_func -def rank0_block(a: T.handle) -> None: - A = T.match_buffer(a, (), "float32") - B = T.alloc_buffer((), "float32") - T.store(B.data, 0, T.load("float32", A.data, 0)) +def rank0(): + @T.prim_func + def rank0(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + B = T.alloc_buffer((), "float32") + A[()] = 2 + B[()] = A[()] - with T.block("update") as []: - T.reads([A[()]]) - T.writes([B[()]]) - for i in range(1): - B[()] = A[()] + return rank0 def rank0_block(): @@ -2999,7 +2958,7 @@ def primfunc_with_allocate_annotations(): def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"}) @@ -3025,7 +2984,7 @@ def comm_reducer_single_reduce_group(): def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) threadIdx_x = T.env_thread("threadIdx.x") - A = T.match_buffer(a, [128, 128], dtype="float32") + A = T.match_buffer(a, [128 * 128], dtype="float32") for i in T.serial(0, 128): T.launch_thread(threadIdx_x, 128) reduce_temp0 = T.allocate([1], "float32", "local") @@ -3040,7 +2999,7 @@ def comm_reducer_multiple_reduce_groups(): def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) threadIdx_x = T.env_thread("threadIdx.x") - A = T.match_buffer(a, [128, 128], dtype="float32") + A = T.match_buffer(a, [128 * 128], dtype="float32") for i in T.serial(0, 128): T.launch_thread(threadIdx_x, 128) reduce_temp0 = T.allocate([1], "float32", "local") @@ -3206,6 +3165,8 @@ def func_T_ptr_allocate() -> None: opt_conv_tensorcore_mod_host, vthread_func, matmul, + module_const, + constant, rank0, rank0_block, select, From f29d4171b26e526f0bce0b6dd96a4cef497be1a4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 24 Feb 2022 10:56:05 -0600 Subject: [PATCH 139/177] Fix lint errors. --- python/tvm/relay/backend/contrib/ethosu/tir/passes.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 3fb1334801a4..5f0b9fe3b690 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -315,9 +315,6 @@ def EncodeConstants(const_dict): """ new_const_dict = {} - pointer_to_buffer = {} - rewrite_buffer = {} - rewrite_pointer = {} def collect_encoding_definitions(stmt, old_buffer_to_const): # Map from copy destination to copy source. From bf65156aed3236eeedcb2baf368017553a97ef6c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 24 Feb 2022 14:25:28 -0600 Subject: [PATCH 140/177] Further updates for ethos-u tests. --- .../tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py | 4 ++-- src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc | 3 --- tests/python/contrib/test_ethosu/infra.py | 4 ++-- .../contrib/test_ethosu/test_replace_unary_elementwise.py | 4 ++-- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 9b1bbf4d03c0..c00151774953 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -122,9 +122,9 @@ def analyze_pool_access(stmt): if isinstance(stmt, tvm.tir.stmt.LetStmt): call_address_of = stmt.value load = call_address_of.args[0] - pool_var = load.buffer_var + pool_var = load.buffer.data scratch_region_map[stmt.var] = RegionOffset( - region=pool_var_region_map[pool_var], offset=int(load.index) + region=pool_var_region_map[pool_var], offset=int(load.indices[0]) ) tvm.tir.stmt_functor.post_order_visit(primfunc.body, analyze_pool_access) diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index f27b53c5351b..b73534090ab5 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -176,9 +176,6 @@ Optional PoolAllocationToOffsetConverter::GetResourceHandle(const PrimFunc& PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo( const PrimFunc& original_func) { - ICHECK_EQ(original_func->preflattened_buffer_map.size(), 0) - << "ConvertPoolAllocationsToOffsets pass expects to operate on pre-flattened buffers, prior " - "to StorageFlatten (TE schedules) or FlattenBuffers (TIR schedules)"; ScopeInfo si; Optional resource_handle = GetResourceHandle(original_func); diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 70932a23a532..2ad6f3039631 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -519,8 +519,8 @@ def get_pooling_args(call, include_buffers=False): for i, arg in enumerate(args): if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): pooling_args.append(arg.value) - elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: - pooling_args.append(arg.index) + elif isinstance(arg, tvm.tir.expr.BufferLoad) and not include_buffers: + pooling_args.append(arg.indices[0]) else: pooling_args.append(arg) diff --git a/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py index e1c633e1d569..498609fb15b7 100644 --- a/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py +++ b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py @@ -33,8 +33,8 @@ def _get_unary_elementwise_args(call, include_buffers=False, remove_constants=Fa for i, arg in enumerate(args): if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): unary_elementwise_args.append(arg.value) - elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: - unary_elementwise_args.append(arg.index) + elif isinstance(arg, tvm.tir.expr.BufferLoad) and not include_buffers: + unary_elementwise_args.append(arg.indices[0]) else: unary_elementwise_args.append(arg) From 2c60f512ec8bcf84163b5dd6a2d768a56bb9ac60 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 24 Feb 2022 15:02:40 -0600 Subject: [PATCH 141/177] Updated ethos.u buffer sizes in test. --- .../contrib/test_ethosu/test_scheduler.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 40f42df9eaad..5c6f064873ef 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -190,17 +190,18 @@ def main(input_buffer: T.Buffer[(301056,), "int8"], output_buffer: T.Buffer[(752 weight_buffer2 = T.buffer_decl([736], "uint8") bias_buffer2 = T.buffer_decl([240], "uint8") - placeholder_global = T.allocate([2608], "uint8", "global", annotations={"disable_lower_builtin":True}) - placeholder_d_global = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True}) + weight_global = T.allocate([2608], "uint8", "global", annotations={"disable_lower_builtin":True}) + weight_global2 = T.buffer_decl([736], "uint8", data=weight_global.data) + bias_global = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True}) featuremap_buffer = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin": True}) featuremap_buffer2 = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", weight_buffer[0], 2608, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", bias_buffer[0], 240, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, input_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 2608, 12, placeholder_d_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", weight_buffer2[0], 736, placeholder_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", bias_buffer2[0], 240, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 736, 12, placeholder_d_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", weight_buffer[0], 2608, weight_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", bias_buffer[0], 240, bias_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, input_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, weight_global[0], 2608, 12, bias_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", weight_buffer2[0], 736, weight_global2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", bias_buffer2[0], 240, bias_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, weight_global2[0], 736, 12, bias_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, output_buffer[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", dtype="handle")) __tvm_meta__ = None # fmt: on From a79b0acbd28e2a447d6cb214f722653ebc2fdf23 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 24 Feb 2022 15:55:35 -0600 Subject: [PATCH 142/177] Updated tir.BindParams to use BufferLoad instead of Load. --- src/tir/transforms/bind_params.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc index 944a67a879fd..1d2b2db207bd 100644 --- a/src/tir/transforms/bind_params.cc +++ b/src/tir/transforms/bind_params.cc @@ -53,12 +53,11 @@ class ParamsCollector : public StmtExprVisitor { return constant_list_; } - void VisitExpr_(const LoadNode* ln) { - if (constant_map_.find(ln->buffer_var) != constant_map_.end()) { - auto it = - std::find(constant_list_.begin(), constant_list_.end(), ln->buffer_var.operator->()); + void VisitExpr_(const BufferLoadNode* ln) { + if (constant_map_.find(ln->buffer->data) != constant_map_.end()) { + auto it = std::find(constant_list_.begin(), constant_list_.end(), ln->buffer->data.get()); if (it == constant_list_.end()) { - constant_list_.push_back(ln->buffer_var.operator->()); + constant_list_.push_back(ln->buffer->data.get()); } } StmtExprVisitor::VisitExpr_(ln); From 62c3f901679e74d556ce8543b7b29ec67e96cedb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 25 Feb 2022 09:25:50 -0600 Subject: [PATCH 143/177] Updated topi.cuda.scan implementation to follow buffer dimensions. --- python/tvm/topi/cuda/scan.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 0d19a92f2058..84f40721c6bb 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -60,9 +60,17 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i your operation. """ + shape = data.shape batch_size = prod(data.shape[:-1]) scan_axis_size = data.shape[-1] + def indices(index): + out = [] + for dim in reversed(shape): + out.append(tvm.tir.indexmod(index, dim)) + index = tvm.tir.indexdiv(index, dim) + return reversed(out) + ib = tvm.tir.ir_builder.create() data = ib.buffer_ptr(data) @@ -95,7 +103,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i ib.scope_attr(by, "thread_extent", nthread_by) tid = bx * nthread_tx + tx with ib.if_scope(tid < scan_axis_size): - output[by * scan_axis_size + tid] = cast(data[by * scan_axis_size + tid], out_dtype) + output[indices(by * scan_axis_size + tid)] = cast( + data[indices(by * scan_axis_size + tid)], out_dtype + ) nthread_tx = max_threads nthread_bx = ceil_div(scan_axis_size, max_threads) @@ -129,9 +139,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i middle[0] = start[0] + tvm.tir.indexdiv(width, 2) end[0] = tvm.te.min(start[0] + width, scan_axis_size) with ib.if_scope(middle[0] < scan_axis_size): - output[by * scan_axis_size + end[0] - 1] = binop( - output[by * scan_axis_size + end[0] - 1], - output[by * scan_axis_size + middle[0] - 1], + output[indices(by * scan_axis_size + end[0] - 1)] = binop( + output[indices(by * scan_axis_size + end[0] - 1)], + output[indices(by * scan_axis_size + middle[0] - 1)], ) # Down Sweep of exclusive scan @@ -140,8 +150,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i ib.scope_attr(bx, "thread_extent", batch_size) with ib.if_scope(bx < batch_size): if reduction is not None: - reduction[bx] = output[(bx + 1) * scan_axis_size - 1] - output[(bx + 1) * scan_axis_size - 1] = cast(identity_value, out_dtype) + reduction[indices(bx)[:-1]] = output[indices((bx + 1) * scan_axis_size - 1)] + output[indices((bx + 1) * scan_axis_size - 1)] = cast(identity_value, out_dtype) with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << (lim - l2_width - 1) @@ -168,12 +178,12 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i middle[0] = start[0] + tvm.tir.indexdiv(width, 2) end[0] = tvm.tir.min(start[0] + width, scan_axis_size) with ib.if_scope(middle[0] < scan_axis_size): - tmp[0] = output[by * scan_axis_size + middle[0] - 1] - output[by * scan_axis_size + middle[0] - 1] = output[ - by * scan_axis_size + end[0] - 1 + tmp[0] = output[indices(by * scan_axis_size + middle[0] - 1)] + output[indices(by * scan_axis_size + middle[0] - 1)] = output[ + indices(by * scan_axis_size + end[0] - 1) ] - output[by * scan_axis_size + end[0] - 1] = binop( - output[by * scan_axis_size + end[0] - 1], tmp[0] + output[indices(by * scan_axis_size + end[0] - 1)] = binop( + output[indices(by * scan_axis_size + end[0] - 1)], tmp[0] ) return ib.get() From 07dc8ab9c801f460619352dd5227e8f1045dcb43 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 25 Feb 2022 10:29:21 -0600 Subject: [PATCH 144/177] Resolved breakage when flattening AllocateConst nodes. --- src/tir/transforms/storage_flatten.cc | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 822746333fde..a92ef43f7afc 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -501,6 +501,11 @@ class BufferStrideLegalize : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } + Stmt VisitStmt_(const AllocateConstNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const BufferRealizeNode* op) final { Buffer key = op->buffer; Buffer with_strides = WithStrides(op->buffer); @@ -847,6 +852,11 @@ class BufferBindUnwrapper : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } + Stmt VisitStmt_(const AllocateConstNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); @@ -1299,6 +1309,11 @@ class StorageFlattener : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } + Stmt VisitStmt_(const AllocateConstNode* op) final { + allocate_node_var_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const BufferRealizeNode* op) final { const auto& key = op->buffer; From cc1f3ae7e813f6c81508cce297608cdf6ec42229 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 25 Feb 2022 12:28:12 -0600 Subject: [PATCH 145/177] Resolved breakages from latest merge with main. --- .../test_tir_transform_extract_constants.py | 6 +- .../test_tir_transform_loop_partition.py | 66 +++++++------------ .../test_tir_transform_narrow_datatype.py | 12 ++-- .../unittest/test_tvmscript_error_report.py | 3 +- 4 files changed, 32 insertions(+), 55 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_extract_constants.py b/tests/python/unittest/test_tir_transform_extract_constants.py index 74144f252ade..9636a9bdde4c 100644 --- a/tests/python/unittest/test_tir_transform_extract_constants.py +++ b/tests/python/unittest/test_tir_transform_extract_constants.py @@ -28,7 +28,7 @@ def constant1(a: T.handle) -> None: B = T.alloc_buffer((10), "int32") K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K, x) + B[x] = A[x] + K[x] @T.prim_func def constant2(a: T.handle) -> None: @@ -36,7 +36,7 @@ def constant2(a: T.handle) -> None: B = T.alloc_buffer((10), "int32") K = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K, x) + B[x] = A[x] + K[x] @T.prim_func def constant3(a: T.handle) -> None: @@ -44,7 +44,7 @@ def constant3(a: T.handle) -> None: B = T.alloc_buffer((10), "int32") K = T.allocate_const([1, 2, 3, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K, x) + B[x] = A[x] + K[x] def test_const_extraction(): diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index c07772de2d40..6cfe96664d89 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -570,64 +570,42 @@ def test_explicit_partition_hint(): @T.prim_func def partitioned_concat_3( - placeholder: T.Buffer[(1, 64, 28, 28), "int8"], - placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"], - placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"], - T_concat: T.Buffer[(1, 128, 28, 28), "int8"], + placeholder: T.Buffer[(50176,), "int8"], + placeholder_1: T.Buffer[(25088,), "int8"], + placeholder_2: T.Buffer[(25088,), "int8"], + T_concat: T.Buffer[(100352,), "int8"], ) -> None: + T.preflattened_buffer(placeholder, [1, 64, 28, 28], "int8", data=placeholder.data) + T.preflattened_buffer(placeholder_1, [1, 32, 28, 28], "int8", data=placeholder_1.data) + T.preflattened_buffer(placeholder_2, [1, 32, 28, 28], "int8", data=placeholder_2.data) + T.preflattened_buffer(T_concat, [1, 128, 28, 28], "int8", data=T_concat.data) for i1, i2, i3 in T.grid(64, 28, 28): - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3, - T.load("int8", placeholder.data, i1 * 784 + i2 * 28 + i3), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] for i1, i2, i3 in T.grid(32, 28, 28): - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3 + 50176, - T.load("int8", placeholder_1.data, i1 * 784 + i2 * 28 + i3), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1[i1 * 784 + i2 * 28 + i3] for i1, i2, i3 in T.grid(32, 28, 28): - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3 + 75264, - T.load("int8", placeholder_2.data, i1 * 784 + i2 * 28 + i3), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3 + 75264] = placeholder_2[i1 * 784 + i2 * 28 + i3] @T.prim_func def concat_func_3( - placeholder: T.Buffer[(1, 64, 28, 28), "int8"], - placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"], - placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"], - T_concat: T.Buffer[(1, 128, 28, 28), "int8"], + placeholder: T.Buffer[(50176,), "int8"], + placeholder_1: T.Buffer[(25088,), "int8"], + placeholder_2: T.Buffer[(25088,), "int8"], + T_concat: T.Buffer[(100352,), "int8"], ) -> None: + T.preflattened_buffer(placeholder, (1, 64, 28, 28), "int8", data=placeholder.data) + T.preflattened_buffer(placeholder_1, (1, 32, 28, 28), "int8", data=placeholder_1.data) + T.preflattened_buffer(placeholder_2, (1, 32, 28, 28), "int8", data=placeholder_2.data) + T.preflattened_buffer(T_concat, (1, 128, 28, 28), "int8", data=T_concat.data) for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}): for i2, i3 in T.grid(28, 28): if 96 <= i1: - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3, - T.load("int8", placeholder_2.data, i1 * 784 + i2 * 28 + i3 - 75264), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_2[i1 * 784 + i2 * 28 + i3 - 75264] if 64 <= i1 and i1 < 96: - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3, - T.load("int8", placeholder_1.data, i1 * 784 + i2 * 28 + i3 - 50176), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_1[i1 * 784 + i2 * 28 + i3 - 50176] if i1 < 64: - T.store( - T_concat.data, - i1 * 784 + i2 * 28 + i3, - T.load("int8", placeholder.data, i1 * 784 + i2 * 28 + i3), - True, - ) + T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] def test_condition_mutually_exclusive(): diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index 0cd3e138bab1..51c382309856 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -51,9 +51,9 @@ def lower_sch(sch, args, target_bits, extra_passes=None): def test_basic(): def check(m, n, target_bits, target_dtype): ib = tvm.tir.ir_builder.create() - Ab = tvm.tir.decl_buffer((m, n), name="A") + Ab = tvm.tir.decl_buffer([m * n], name="A") A = ib.buffer_ptr(Ab) - Bb = tvm.tir.decl_buffer((m, n), name="B") + Bb = tvm.tir.decl_buffer([m * n], name="B") B = ib.buffer_ptr(Bb) with ib.for_range(0, m, name="i") as i: with ib.for_range(0, n, name="j") as j: @@ -83,9 +83,9 @@ def check(m, n, target_bits, target_dtype): def test_thread_axis(): def check(m, n, target_bits, target_dtype): ib = tvm.tir.ir_builder.create() - Ab = tvm.tir.decl_buffer((m, n), name="A") + Ab = tvm.tir.decl_buffer([m * n], name="A") A = ib.buffer_ptr(Ab) - Bb = tvm.tir.decl_buffer((m, n), name="B") + Bb = tvm.tir.decl_buffer([m * n], name="B") B = ib.buffer_ptr(Bb) bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") @@ -168,9 +168,9 @@ def test_slice(): def check(m, n, target_bits, target_dtype): # The index may overflow in B, while not in A ib = tvm.tir.ir_builder.create() - Ab = tvm.tir.decl_buffer((m, n), name="A") + Ab = tvm.tir.decl_buffer([m * n], name="A") A = ib.buffer_ptr(Ab) - Bb = tvm.tir.decl_buffer((m, n * 2), name="B") + Bb = tvm.tir.decl_buffer([m * n * 2], name="B") B = ib.buffer_ptr(Bb) with ib.for_range(0, m, name="i") as i: with ib.for_range(0, n, name="j") as j: diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index bc66f6ddd90f..462142e2e534 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -332,8 +332,7 @@ def opaque_access_during_complete(a: T.handle) -> None: # error A = T.match_buffer(a, (16, 16), "float32") for i, j in T.grid(16, 16): with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - T.evaluate(A[vi * 16 + vj]) + T.evaluate(T.call_extern("dummy_extern_function", A.data, dtype="int32")) def test_opaque_access_during_complete(): From 09d33bb83a7d0cda8a250cba7e738004804491d1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 25 Feb 2022 12:30:02 -0600 Subject: [PATCH 146/177] Corrected error in merge. --- python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index fafbf8fad5da..33a22d1a09fb 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -425,6 +425,8 @@ def replace_npu_fm_with_address(npu_fm): return npu_fm def replace_npu_address_range_with_address(npu_addr_range): + assert isinstance(npu_addr_range.address, tvm.tir.BufferLoad) + buffer = npu_addr_range.address.buffer.data index = int( npu_addr_range.address.indices[0] * (np.iinfo(np.dtype(npu_addr_range.address)).bits // 8) From 24297e35cee3b41e8a9e0dc68745081762e9b6ed Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 25 Feb 2022 16:00:07 -0600 Subject: [PATCH 147/177] Use empty indices for rank-0 tensor. --- include/tvm/topi/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 04027f8974fe..36f490798a8e 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1815,7 +1815,7 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array& indices) { PrimExpr ret = default_value; if (0 == rank_sparse_indices) { - ret = if_then_else(indices[0] == sparse_indices[0], sparse_values[0], ret); + ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret); } else if (1 == rank_sparse_indices) { for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret); From 14676bbc1ddedcc3039dede280ff733ef2e90b43 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 26 Feb 2022 06:36:43 -0600 Subject: [PATCH 148/177] Added ir_builder workaround for 1-d indexing. --- python/tvm/tir/ir_builder.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 092adf6901c6..928e5007df61 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Developer API of IR node builder make function.""" +import tvm from tvm._ffi.base import string_types from tvm.runtime import ObjectGeneric, DataType, convert, const from tvm.ir import container as _container @@ -97,6 +98,12 @@ def _normalize_index(self, index): index = [x.var if isinstance(x, _expr.IterVar) else x for x in index] + # Workaround to support previous behavior of ir_builder + # indexing by a single index, treating the buffer as if were + # already flattened. + if len(index) == 1 and len(self._buffer.shape) != 1: + index = tvm.topi.utils.unravel_index(index[0], self._buffer.shape) + return index def __getitem__(self, index): From 03f916434862af23d5d557b374fe7f55a74e484a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 26 Feb 2022 17:14:07 -0600 Subject: [PATCH 149/177] Consistent buffer access type in LLVM codegen, to match C codegen --- src/target/llvm/codegen_llvm.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 028b95a4f037..2cb7d612b5d1 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -989,10 +989,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { ICHECK_EQ(load->indices.size(), 1) << "LLVM only supports flat memory allocations."; PrimExpr index = load->indices[0]; if (const RampNode* r = index.as()) { - index = r->base / make_const(DataType::Int(32), r->lanes); + index = r->base; } TypedPointer buffer_ptr = - CreateBufferPtr(load->dtype, MakeValue(load->buffer->data), MakeValue(index)); + CreateBufferPtr(load->buffer->dtype, MakeValue(load->buffer->data), MakeValue(index)); unsigned addrspace = llvm::dyn_cast(buffer_ptr.addr->getType())->getAddressSpace(); return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace)); @@ -1263,7 +1263,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { if (t.lanes() == buffer_element_dtype.lanes()) { int alignment, native_bits; GetAlignment(t, buffer_var.get(), buffer_index, &alignment, &native_bits); - TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(buffer_element_dtype, buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), is_volatile); @@ -1308,7 +1308,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { int basic_align = t.bits() / 8; llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t)); auto f = [&](int i, llvm::Value* index) { - TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(buffer_element_dtype, buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(basic_align), is_volatile); @@ -1405,7 +1405,7 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { if (value_dtype.lanes() == buffer_element_dtype.lanes()) { int alignment, native_bits; GetAlignment(value_dtype, buffer_var.get(), buffer_index, &alignment, &native_bits); - TypedPointer buffer_ptr = CreateBufferPtr(value_dtype, buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(buffer_element_dtype, buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, llvm::Align(alignment), is_volatile); @@ -1446,7 +1446,7 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { // scalarized store. int basic_align = value_dtype.bits() / 8; auto f = [&](int i, llvm::Value* index) { - TypedPointer buffer_ptr = CreateBufferPtr(value_dtype.element_of(), buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(buffer_element_dtype, buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), buffer_ptr.addr, From 6d58d23146549ef390b01b407c9cb28c3ba4865e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 26 Feb 2022 17:15:28 -0600 Subject: [PATCH 150/177] StorageRewrite, update indices of modified buffers. --- src/tir/transforms/storage_rewrite.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 0a524ae024a9..8acfe6f3ede4 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1595,7 +1595,7 @@ Pass StorageRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true); - return PointerValueTypeRewrite(std::move(f), true, false, false, true, false, true); + return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true); }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } From 0a9ebe6cbb795bd5f94ec593cc965ed51ecde2f8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 27 Feb 2022 21:32:38 -0600 Subject: [PATCH 151/177] Dynamic relay nodes, access 0-d tensors with 0-d indices. --- python/tvm/relay/op/dyn/_algorithm.py | 6 +++--- python/tvm/relay/op/dyn/_transform.py | 2 +- python/tvm/relay/op/dyn/nn/_nn.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/op/dyn/_algorithm.py b/python/tvm/relay/op/dyn/_algorithm.py index ba903e680bbd..24550e6d6b6a 100644 --- a/python/tvm/relay/op/dyn/_algorithm.py +++ b/python/tvm/relay/op/dyn/_algorithm.py @@ -42,12 +42,12 @@ def _topk_shape_func_input_data(data, k, axis): val_out[i] = int64(data.shape[i]) indices_out[i] = int64(data.shape[i]) else: - if k[0] < 1: + if k[()] < 1: val_out[i] = int64(data.shape[i]) indices_out[i] = int64(data.shape[i]) else: - val_out[i] = int64(k[0]) - indices_out[i] = int64(k[0]) + val_out[i] = int64(k[()]) + indices_out[i] = int64(k[()]) return val_out, indices_out diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py index c909764319d9..3300c75909cd 100644 --- a/python/tvm/relay/op/dyn/_transform.py +++ b/python/tvm/relay/op/dyn/_transform.py @@ -170,7 +170,7 @@ def _onehot_shape_func(dshape, k, axis): out = output_tensor((ndim,), "int64") for i in const_range(axis): out[i] = int64(dshape[i]) - out[axis] = int64(k[0]) + out[axis] = int64(k[()]) for j in const_range(axis + 1, ndim): out[j] = int64(dshape[j - 1]) return out diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py index 727715141230..ec4066561fce 100644 --- a/python/tvm/relay/op/dyn/nn/_nn.py +++ b/python/tvm/relay/op/dyn/nn/_nn.py @@ -78,8 +78,8 @@ def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis): out = output_tensor((4,), "int64") for i in const_range(4): out[i] = int64(dshape[i]) - out[height_axis] = int64(round(dshape[height_axis] * scale_h[0])) - out[width_axis] = int64(round(dshape[width_axis] * scale_w[0])) + out[height_axis] = int64(round(dshape[height_axis] * scale_h[()])) + out[width_axis] = int64(round(dshape[width_axis] * scale_w[()])) return out @@ -108,9 +108,9 @@ def _upsampling3d_shape_func( out = output_tensor((5,), "int64") for i in const_range(5): out[i] = int64(dshape[i]) - out[depth_axis] = int64(round(dshape[depth_axis] * scale_d[0])) - out[height_axis] = int64(round(dshape[height_axis] * scale_h[0])) - out[width_axis] = int64(round(dshape[width_axis] * scale_w[0])) + out[depth_axis] = int64(round(dshape[depth_axis] * scale_d[()])) + out[height_axis] = int64(round(dshape[height_axis] * scale_h[()])) + out[width_axis] = int64(round(dshape[width_axis] * scale_w[()])) return out From 99357d300db77e5b7bf83e0ba5678c92e0836cda Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 27 Feb 2022 21:34:31 -0600 Subject: [PATCH 152/177] BFloat16 legalization, update buffer type. --- src/tir/transforms/bf16_legalize.cc | 46 +++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 4ae5d1ea5584..193584f84b47 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -199,11 +199,11 @@ class BF16LowerRewriter : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - auto it = buffer_remap_.find(op->buffer); - if (it != buffer_remap_.end()) { - return BufferStore(it->second, op->value, op->indices); - } else { + Buffer new_buf = GetRemappedBuffer(op->buffer); + if (new_buf.same_as(op->buffer)) { return ret; + } else { + return BufferStore(new_buf, op->value, op->indices); } } @@ -229,11 +229,11 @@ class BF16LowerRewriter : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - auto it = buffer_remap_.find(op->buffer); - if (it != buffer_remap_.end()) { - return BufferRealize(it->second, op->bounds, op->condition, op->body); - } else { + Buffer new_buf = GetRemappedBuffer(op->buffer); + if (new_buf.same_as(op->buffer)) { return ret; + } else { + return BufferRealize(new_buf, op->bounds, op->condition, op->body); } } @@ -246,11 +246,11 @@ class BF16LowerRewriter : public StmtExprMutator { PrimExpr ret = StmtExprMutator::VisitExpr_(op); op = ret.as(); - auto it = buffer_remap_.find(op->buffer); - if (it != buffer_remap_.end()) { - return BufferLoad(it->second, op->indices); - } else { + Buffer new_buf = GetRemappedBuffer(op->buffer); + if (new_buf.same_as(op->buffer)) { return ret; + } else { + return BufferLoad(new_buf, op->indices); } } @@ -322,6 +322,28 @@ class BF16LowerRewriter : public StmtExprMutator { } private: + Buffer GetRemappedBuffer(Buffer buf) { + auto buf_it = buffer_remap_.find(buf); + if (buf_it != buffer_remap_.end()) { + return buf_it->second; + } + + Buffer new_buf = buf; + + auto var_it = var_remap_.find(buf->data); + if (var_it != var_remap_.end()) { + DataType dtype = + buf->dtype.is_bfloat16() ? DataType::UInt(16, buf->dtype.lanes()) : buf->dtype; + new_buf = Buffer(var_it->second, dtype, buf->shape, buf->strides, buf->elem_offset, buf->name, + buf->data_alignment, buf->offset_factor, buf->buffer_type, + buf->axis_separators, buf->span); + } + + buffer_remap_[buf] = new_buf; + + return new_buf; + } + std::unordered_map buffer_remap_; std::unordered_map var_remap_; }; From 9dd8afbe63c3111fb8f54eae4a206268f78403c9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 27 Feb 2022 22:13:38 -0600 Subject: [PATCH 153/177] Updated meshgrid to use 0-d index for 0-d buffer. --- include/tvm/topi/transform.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 36f490798a8e..ef36c015957a 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1566,7 +1566,11 @@ inline Array meshgrid(const Array& inputs, const std::string& in out_shape, [&](const Array& indices) { const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i; - Array real_indices = {indices[src_index]}; + auto ndim = inputs[i]->GetShape().size(); + Array real_indices = {}; + if (ndim > 0) { + real_indices = {indices[src_index]}; + } return inputs[i](real_indices); }, name, tag)); From bf2cc9e01c60e99152db8ccb50b1296ffdeaaddb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 28 Feb 2022 08:34:11 -0600 Subject: [PATCH 154/177] Corrected boolean handling in Allocate nodes. --- src/tir/ir/stmt.cc | 3 ++- src/tir/transforms/storage_flatten.cc | 11 +++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index e66c85731cec..3914f41e4f34 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -343,7 +343,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, Stmt body, Map annotations, Span span) { - CHECK(IsPointerType(buffer_var->type_annotation, dtype)) + CHECK(IsPointerType(buffer_var->type_annotation, dtype) || + (dtype.is_bool() && IsPointerType(buffer_var->type_annotation, DataType::Int(8)))) << "The allocated data type (" << dtype << ") does not match the type annotation of the buffer " << buffer_var << " (" << buffer_var->type_annotation diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index a92ef43f7afc..2bc081483ccd 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1281,7 +1281,8 @@ class StorageFlattener : public StmtExprMutator { PrimExpr value = op->value; if (value.dtype() == DataType::Bool()) { ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) - << "Expected int8 backing array for boolean tensor"; + << "Expected int8 backing array for boolean tensor, but received " + << e.flattened_buffer->dtype; value = tir::Cast(DataType::Int(8), value); } @@ -1408,7 +1409,8 @@ class StorageFlattener : public StmtExprMutator { if (op->dtype == DataType::Bool()) { ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) - << "Expected int8 backing array for boolean tensor"; + << "Expected int8 backing array for boolean tensor, but received " + << e.flattened_buffer->dtype; val = tir::Cast(DataType::Bool(), val); } @@ -1531,6 +1533,11 @@ class StorageFlattener : public StmtExprMutator { BufferEntry entry; entry.buffer = buffer; entry.flattened_buffer = buffer.GetFlattenedBuffer(); + // Boolean tensors are backed by a Int8 array. + if (entry.flattened_buffer->dtype == DataType::Bool()) { + auto writer = entry.flattened_buffer.CopyOnWrite(); + writer->dtype = DataType::Int(8); + } buf_map_[buffer] = std::move(entry); } From 8dbc57100b8a8495c8af39c5dc886f85e1431396 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 28 Feb 2022 09:22:23 -0600 Subject: [PATCH 155/177] Added workaround to unpack 1-d Tensor indices into N-d buffer indices. --- python/tvm/relay/op/dyn/_algorithm.py | 6 ++-- .../schedule/schedule_postproc_to_primfunc.cc | 28 +++++++++++++++++-- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/dyn/_algorithm.py b/python/tvm/relay/op/dyn/_algorithm.py index 24550e6d6b6a..ba903e680bbd 100644 --- a/python/tvm/relay/op/dyn/_algorithm.py +++ b/python/tvm/relay/op/dyn/_algorithm.py @@ -42,12 +42,12 @@ def _topk_shape_func_input_data(data, k, axis): val_out[i] = int64(data.shape[i]) indices_out[i] = int64(data.shape[i]) else: - if k[()] < 1: + if k[0] < 1: val_out[i] = int64(data.shape[i]) indices_out[i] = int64(data.shape[i]) else: - val_out[i] = int64(k[()]) - indices_out[i] = int64(k[()]) + val_out[i] = int64(k[0]) + indices_out[i] = int64(k[0]) return val_out, indices_out diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index e8cd0b387f90..0cf6e54391da 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -121,7 +121,7 @@ class TensorToBufferMapper : public StmtExprMutator { auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - return BufferStore(buffer, op->value, op->indices); + return BufferStore(buffer, op->value, GetIndices(op->indices, buffer->shape)); } PrimExpr VisitExpr_(const ProducerLoadNode* op) final { @@ -129,7 +129,7 @@ class TensorToBufferMapper : public StmtExprMutator { op = ret.as(); Tensor tensor = Downcast(op->producer); Buffer buffer = GetBuffer(tensor); - return tir::BufferLoad(buffer, op->indices); + return tir::BufferLoad(buffer, GetIndices(op->indices, buffer->shape)); } private: @@ -147,6 +147,30 @@ class TensorToBufferMapper : public StmtExprMutator { return buffer; } + Array GetIndices(const Array& tensor_indices, + const Array& buffer_shape) { + if (tensor_indices.size() == buffer_shape.size()) { + return tensor_indices; + } else if (tensor_indices.size() == 1) { + // Workaround to support previous behavior of tensor indexing by + // a single index, treating the tensor as if were already + // flattened by a row-major traversal. + PrimExpr unravel = tensor_indices[0]; + Array rev_indices; + for (size_t i = buffer_shape.size(); i > 0; i--) { + PrimExpr dim = buffer_shape[i - 1]; + rev_indices.push_back(indexmod(unravel, dim)); + unravel = indexdiv(unravel, dim); + } + return Array(rev_indices.rbegin(), rev_indices.rend()); + } else { + LOG(FATAL) << "Cannot produce indices for " << buffer_shape.size() + << "-dimensional TIR buffer using " << tensor_indices.size() + << "-dimensional tensor indices."; + return {}; + } + } + // Maps tensor to buffer. std::unordered_map buffer_map_; }; From f6deec11be96d26476b45dd0272bf96b49c0060b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 28 Feb 2022 10:57:53 -0600 Subject: [PATCH 156/177] Resolved a few more failures in relay tests on cuda. --- python/tvm/topi/cuda/scan.py | 31 +++++++------------- python/tvm/topi/utils.py | 9 ++++-- src/tir/transforms/lower_thread_allreduce.cc | 9 ++++-- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 84f40721c6bb..31989742c567 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -64,13 +64,6 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i batch_size = prod(data.shape[:-1]) scan_axis_size = data.shape[-1] - def indices(index): - out = [] - for dim in reversed(shape): - out.append(tvm.tir.indexmod(index, dim)) - index = tvm.tir.indexdiv(index, dim) - return reversed(out) - ib = tvm.tir.ir_builder.create() data = ib.buffer_ptr(data) @@ -103,9 +96,7 @@ def indices(index): ib.scope_attr(by, "thread_extent", nthread_by) tid = bx * nthread_tx + tx with ib.if_scope(tid < scan_axis_size): - output[indices(by * scan_axis_size + tid)] = cast( - data[indices(by * scan_axis_size + tid)], out_dtype - ) + output[by * scan_axis_size + tid] = cast(data[by * scan_axis_size + tid], out_dtype) nthread_tx = max_threads nthread_bx = ceil_div(scan_axis_size, max_threads) @@ -139,9 +130,9 @@ def indices(index): middle[0] = start[0] + tvm.tir.indexdiv(width, 2) end[0] = tvm.te.min(start[0] + width, scan_axis_size) with ib.if_scope(middle[0] < scan_axis_size): - output[indices(by * scan_axis_size + end[0] - 1)] = binop( - output[indices(by * scan_axis_size + end[0] - 1)], - output[indices(by * scan_axis_size + middle[0] - 1)], + output[by * scan_axis_size + end[0] - 1] = binop( + output[by * scan_axis_size + end[0] - 1], + output[by * scan_axis_size + middle[0] - 1], ) # Down Sweep of exclusive scan @@ -150,8 +141,8 @@ def indices(index): ib.scope_attr(bx, "thread_extent", batch_size) with ib.if_scope(bx < batch_size): if reduction is not None: - reduction[indices(bx)[:-1]] = output[indices((bx + 1) * scan_axis_size - 1)] - output[indices((bx + 1) * scan_axis_size - 1)] = cast(identity_value, out_dtype) + reduction[bx] = output[(bx + 1) * scan_axis_size - 1] + output[(bx + 1) * scan_axis_size - 1] = cast(identity_value, out_dtype) with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << (lim - l2_width - 1) @@ -178,12 +169,12 @@ def indices(index): middle[0] = start[0] + tvm.tir.indexdiv(width, 2) end[0] = tvm.tir.min(start[0] + width, scan_axis_size) with ib.if_scope(middle[0] < scan_axis_size): - tmp[0] = output[indices(by * scan_axis_size + middle[0] - 1)] - output[indices(by * scan_axis_size + middle[0] - 1)] = output[ - indices(by * scan_axis_size + end[0] - 1) + tmp[0] = output[by * scan_axis_size + middle[0] - 1] + output[by * scan_axis_size + middle[0] - 1] = output[ + by * scan_axis_size + end[0] - 1 ] - output[indices(by * scan_axis_size + end[0] - 1)] = binop( - output[indices(by * scan_axis_size + end[0] - 1)], tmp[0] + output[by * scan_axis_size + end[0] - 1] = binop( + output[by * scan_axis_size + end[0] - 1], tmp[0] ) return ib.get() diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index be3df2be5f6a..34568b004a25 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -310,9 +310,12 @@ def unravel_index(idx, shape): idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod indices = [] - for i in range(len(shape) - 1, -1, -1): - indices.append(idxm(idx, shape[i])) - idx = idxd(idx, shape[i]) + for dim in reversed(shape): + if dim == 0: + indices.append(0) + else: + indices.append(idxm(idx, dim)) + idx = idxd(idx, dim) indices = indices[::-1] return indices diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 93e61c3b3e30..7e09943d0185 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -227,8 +227,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } std::vector buffers(size); for (size_t idx = 0; idx < size; ++idx) { - auto dummy_load = Downcast(call->args[2 + size + idx]); - buffers[idx] = dummy_load->buffer; + PrimExpr arg = call->args[2 + size + idx]; + // Loads from boolean buffers may have cast nodes inserted by + // earlier passes. + if (auto cast = arg.as()) { + arg = cast->value; + } + buffers[idx] = Downcast(arg)->buffer; } std::unordered_set reduce_set; From 77ef980fa3d22d54c10d0d6312ae1b6e71c5a9d1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 28 Feb 2022 18:52:52 -0600 Subject: [PATCH 157/177] Resolve linting --- python/tvm/topi/cuda/scan.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 31989742c567..0d19a92f2058 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -60,7 +60,6 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i your operation. """ - shape = data.shape batch_size = prod(data.shape[:-1]) scan_axis_size = data.shape[-1] From 795c3fccf004cc99d56eb3db12a7d90a401f799f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 1 Mar 2022 08:07:00 -0600 Subject: [PATCH 158/177] CI bump From 4703aa200200332c6cc8657b08249259114bfcf3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 1 Mar 2022 12:07:07 -0600 Subject: [PATCH 159/177] Updated renormalize_split_pattern tests to use BufferLoad/BufferStore --- python/tvm/tir/transform/transform.py | 4 +- ...tir_transform_renormalize_split_pattern.py | 49 +++++++++++-------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 4a41d98c1c78..802fdc576c41 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -797,7 +797,7 @@ def ExtractPrimFuncConstants(): return _ffi_api.ExtractPrimFuncConstants() # type: ignore -def RenomalizeSplitPattern(): +def RenormalizeSplitPattern(): """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) Returns @@ -805,4 +805,4 @@ def RenomalizeSplitPattern(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.RenormalizeSplitPattern() # type: ignore + return _ffi_api.RenormalizeSplitPattern() # type: ignore diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py index eb3efd317e9c..7f60c95164a8 100644 --- a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py +++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py @@ -24,9 +24,12 @@ @tvm.script.ir_module class Before: @T.prim_func - def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) + T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) + T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -37,24 +40,27 @@ def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 51 weight_shared = T.allocate([4096], "float32", "shared") T.launch_thread(threadIdx_x, 32) for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] @tvm.script.ir_module class After: @T.prim_func - def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) + T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) + T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -65,24 +71,27 @@ def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 51 weight_shared = T.allocate([4096], "float32", "shared") T.launch_thread(threadIdx_x, 32) for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] @tvm.script.ir_module class After_simplified: @T.prim_func - def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) + T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) + T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -93,23 +102,23 @@ def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 51 weight_shared = T.allocate([4096], "float32", "shared") T.launch_thread(threadIdx_x, 32) for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0) for i6_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): - T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): - T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4): - T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2] # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg # fmt: on -def tesd_renormalize_split_pattern(): - after = tvm.tir.transform.RenomalizeSplitPattern()(Before) +def test_renormalize_split_pattern(): + after = tvm.tir.transform.RenormalizeSplitPattern()(Before) tvm.ir.assert_structural_equal(after, After) after = tvm.tir.transform.Simplify()(after) tvm.ir.assert_structural_equal(after, After_simplified) From 94abb53623094dc8fc2fee07e491472d3e9ef78d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 1 Mar 2022 14:49:05 -0600 Subject: [PATCH 160/177] Fixed cuda codegen checks for BufferStore/Ramp. --- .../unittest/test_target_codegen_cuda.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 483fdf581172..994a85095728 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import re + import tvm from tvm import te import numpy as np @@ -811,14 +813,18 @@ def pre_visit(stmt): inside_broadcast[0] = True # Check Broadcast[Imm numbers] or Broadcast[Load] patterns assert isinstance(stmt.value, (tvm.tir.IntImm, tvm.tir.FloatImm, tvm.tir.BufferLoad)) - if isinstance(stmt, tvm.tir.Store): - # Check Store[Ramp] pattern - assert isinstance(stmt.index, tvm.tir.Ramp) - if isinstance(stmt, tvm.tir.BufferLoad): - # Check Broadcast[BufferLoad] or BufferLoad[Ramp] patterns - assert inside_broadcast[0] or isinstance(stmt.indices[-1], tvm.tir.Ramp) - # Skip the rest - return stmt + + if isinstance(stmt, (tvm.tir.BufferStore, tvm.tir.BufferLoad)): + is_ramp_index = isinstance(stmt.indices[-1], tvm.tir.Ramp) + is_vectorized_buffer = re.match(r"^.*x\d+$", stmt.buffer.dtype) + if isinstance(stmt, tvm.tir.BufferLoad): + # Check Broadcast[BufferLoad] or BufferLoad[Ramp] patterns + assert inside_broadcast[0] or is_ramp_index or is_vectorized_buffer + # Skip the rest of the BufferLoad + return stmt + else: + assert is_ramp_index or is_vectorized_buffer + return None def post_visit(stmt): From 4df4be303f7c48b0d75024045b143193e2b664e1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 1 Mar 2022 16:04:35 -0600 Subject: [PATCH 161/177] Simplify indices further, needed to avoid cuda register limit. --- python/tvm/topi/utils.py | 7 ++++++- src/arith/rewrite_simplify.cc | 6 +++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 34568b004a25..9934a045875a 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -310,9 +310,14 @@ def unravel_index(idx, shape): idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod indices = [] - for dim in reversed(shape): + for i,dim in enumerate(reversed(shape)): if dim == 0: indices.append(0) + elif i == len(shape)-1: + # Assuming the index is in-bounds, the last coordinate is + # already less than dim, and doesn't need the be remainder + # mod dim. + indices.append(idx) else: indices.append(idxm(idx, dim)) idx = idxd(idx, dim) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index e11bd024bb22..732045384a95 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -191,7 +191,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // truc div TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x); // floor div - TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x); + TVM_TRY_REWRITE(floordiv(x, y) * y + floormod(x, y), x); + TVM_TRY_REWRITE(y * floordiv(x, y) + floormod(x, y), x); + TVM_TRY_REWRITE(floormod(x, y) + floordiv(x, y) * y, x); + TVM_TRY_REWRITE(floormod(x, y) + y * floordiv(x, y), x); + TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2), c2.Eval()->value > 0); From 3373ecd5d4f1309a1f2ec8c584d3652c0f9d3b9d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Mar 2022 17:03:30 +0900 Subject: [PATCH 162/177] fixed dyn onehot shape func accessing 1d buffer with () --- python/tvm/relay/op/dyn/_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py index 3300c75909cd..d523d43d9c64 100644 --- a/python/tvm/relay/op/dyn/_transform.py +++ b/python/tvm/relay/op/dyn/_transform.py @@ -170,7 +170,7 @@ def _onehot_shape_func(dshape, k, axis): out = output_tensor((ndim,), "int64") for i in const_range(axis): out[i] = int64(dshape[i]) - out[axis] = int64(k[()]) + out[axis] = int64(k[(0)]) for j in const_range(axis + 1, ndim): out[j] = int64(dshape[j - 1]) return out From 983502837bd224f62346bddde8959fd3a00d5c3e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Mar 2022 10:01:54 -0600 Subject: [PATCH 163/177] Fixed codegen indexing for int4 scalar types. --- src/target/source/codegen_c.cc | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 18fdee861402..1752c2a2e826 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -195,12 +195,16 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp std::string index_str = PrintExpr(index); if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { // This is a special case, because CodegenCUDA::PrintType() - // returns "int" for bool and for 4-bit integers. Therefore, we - // need to do the pointer arithmetic in the output's datatype, - // rather than the buffer's element type. + // returns "int" for bool and for 4-bit integers. In most cases, + // we divide by the number of lanes to determine the index. + // However, the backing type for scalar int4 and scalar bool is + // int32. Therefore, we need to divide by the ratio of their + // sizes in that case. + int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes(); + os << "*(" << "(" << ptr_cast(t) << vid << ")" - << " + " << index_str << " / " << t.lanes() << ")"; + << " + " << index_str << " / " << div_factor << ")"; } else if (t == buffer_element_dtype) { os << buffer_str << "[" << index_str << "]"; } else { @@ -565,15 +569,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) const BufferLoadNode* load = op->args[0].as(); ICHECK(op->args.size() == 1 && load); ICHECK_EQ(load->indices.size(), 1) << "CodeGenC only supports flat memory allocations."; - os << "(("; - this->PrintType(load->dtype.element_of(), os); - os << " *)" << this->GetVarID(load->buffer->data.get()) << " + " - << "("; - this->PrintExpr(load->indices[0], os); - if (load->dtype.bits() == 4 || (load->dtype.bits() == 1 && load->dtype.is_int())) { - os << " / " << (32 / load->dtype.bits()); - } - os << "))"; + os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))"; } else if (op->op.same_as(builtin::tvm_struct_get())) { ICHECK_EQ(op->args.size(), 3U); os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); From 942cda1161d2abbbf06f7295fc5b0251138b1947 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 2 Mar 2022 17:36:10 -0600 Subject: [PATCH 164/177] Temporary workaround for incorrect constant folding. Need to further investigate vectorized LLVM constants --- src/relay/transforms/fold_constant.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 3a8391e05856..cebeeb3d73fe 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -252,7 +252,16 @@ class ConstantFolder : public MixedModeMutator { // Use a fresh build context in case we are already in a build context. // needed for both execution and creation(due to JIT) - With fresh_build_ctx(transform::PassContext::Create()); + // With fresh_build_ctx(transform::PassContext::Create()); + + // Disabling vectorization in Eval, as temporary workaround for + // testing PR-9727. Constant folding current produces incorrect + // results, looks to be an issue with vectorized access of + // AllocateConst buffers. + auto fresh_context = transform::PassContext::Create(); + fresh_context->config.Set("tir.disable_vectorize", Bool(true)); + With with_context(fresh_context); + Map dict = (module_->attrs.defined()) ? module_->attrs->dict : Map(); Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), From 7c66f239752e2ec30d86811f9e6c2b4dfb7803f2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 3 Mar 2022 08:23:38 -0600 Subject: [PATCH 165/177] s/find_allocate_usage/FindAllocateUsage/g --- src/printer/tvmscript_printer.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index e0a0ebb31828..a6e506612fb6 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1020,7 +1020,7 @@ struct AllocUsage { }; template -AllocUsage find_allocate_usage(AllocNode* op, Map>* cache_ptr) { +AllocUsage FindAllocateUsage(AllocNode* op, Map>* cache_ptr) { Map>& cache = *cache_ptr; if (!cache.count(op->buffer_var)) { cache = BufferUsageFinder::FindUsage(std::move(cache), op->body); @@ -1061,7 +1061,7 @@ AllocUsage find_allocate_usage(AllocNode* op, Map>* cache_ptr } // namespace Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { - auto usage = find_allocate_usage(op, &buffer_var_usage_); + auto usage = FindAllocateUsage(op, &buffer_var_usage_); Buffer& alloc_buffer = usage.alloc_buffer; Array& aliasing_buffers = usage.aliasing_buffers; buf_not_in_headers_.insert(alloc_buffer.get()); @@ -1124,7 +1124,7 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { } auto ndarray_str = ss.str(); - auto usage = find_allocate_usage(alloc, &buffer_var_usage_); + auto usage = FindAllocateUsage(alloc, &buffer_var_usage_); Buffer& alloc_buffer = usage.alloc_buffer; Array& aliasing_buffers = usage.aliasing_buffers; buf_not_in_headers_.insert(alloc_buffer.get()); From e8c0e625ebb3bd049f0e425efc8a8e10699a025e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 3 Mar 2022 08:29:54 -0600 Subject: [PATCH 166/177] Added buffer type consistency TODO. --- src/tir/ir/buffer.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 684c56ddeeb7..4fe9b162078e 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -526,6 +526,10 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array // tensors without a common datatype. Therefore, we check that the // data pointer is a pointer, but not the exact type of the // pointed-to values. + + // TODO(Lunderberg): Use an explicit pointer cast for the data + // pointer. Should be done alongside extensions to StmtExprMutator + // to more easily handle buffer/buffer_var updates. ICHECK(data->type_annotation.defined()) << "Variable " << data->name_hint << " is missing a type annotation."; ICHECK(data->type_annotation.as()) From af2adf64c7cb6a7a0d0efda39dc27cad0e147ca4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 3 Mar 2022 08:46:22 -0600 Subject: [PATCH 167/177] Improved comment on address_of Op. --- include/tvm/tir/builtin.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 471f6ca719ab..f7e1cfbc3e6d 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -105,7 +105,12 @@ TVM_DLL const Op& large_uint_imm(); TVM_DLL const Op& q_multiply_shift(); /*! - * \brief See pseudo code + * \brief Returns the address of an element in the buffer (see pseudocode below). + * + * The number of indices should match the dimensionality of the buffer + * being accessed. If this operation occurs after buffer flattening, + * the number of indices must be supported by the target (i.e. N>1 + * only on targets that support non-flat memory buffers). * * Handle address_of(BufferLoad *op) { * return &op->buffer_var[op->indices[0], op->indices[1], ..., op->indices[N-1]]; From 50a73e115ba821d7a84b496fe53e15691a270298 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 3 Mar 2022 08:59:09 -0600 Subject: [PATCH 168/177] Rename LegalizeDtype to LegalizeDType, made private. --- include/tvm/tir/expr.h | 24 +++++++++++++------- src/tir/ir/expr.cc | 4 ++-- src/tir/transforms/lower_custom_datatypes.cc | 2 +- src/tir/transforms/storage_rewrite.cc | 2 +- src/tir/transforms/vectorize_loop.cc | 2 +- 5 files changed, 21 insertions(+), 13 deletions(-) diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 4ba27fee70b0..674ff0b7f43c 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -610,14 +610,6 @@ class BufferLoadNode : public PrimExprNode { /*! \brief The indices location to be loaded. */ Array indices; - /*! \brief Set the dtype based on the buffer/indices - * - * Usually, this will be the same dtype as the buffer. This may - * have a different number of lanes than the buffer's dtype if index - * values have more than 1 lane. - */ - void LegalizeDtype(); - void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->dtype)); v->Visit("buffer", &buffer); @@ -638,6 +630,22 @@ class BufferLoadNode : public PrimExprNode { static constexpr const char* _type_key = "tir.BufferLoad"; TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode); + + private: + /*! \brief Set the dtype based on the buffer/indices + * + * Usually, the BufferLoad's dtype will be the same dtype as the + * buffer. This may have a different number of lanes than the + * buffer's dtype if index values have more than 1 lane. + * + * This function should only be called during construction and after + * CopyOnWrite. Friend class used here to restrict usage. + */ + void LegalizeDType(); + friend class BufferLoad; + friend class CustomDatatypesLowerer; + friend class VectorTypeRewriter; + friend class Vectorizer; }; /*! diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 04993603a3dc..ef533ef84b85 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -1058,7 +1058,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; }); // BufferLoad -void BufferLoadNode::LegalizeDtype() { +void BufferLoadNode::LegalizeDType() { int index_lanes = 1; for (const auto& index : indices) { index_lanes *= index.dtype().lanes(); @@ -1079,7 +1079,7 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { node->buffer = std::move(buffer); node->indices = std::move(indices); node->span = std::move(span); - node->LegalizeDtype(); + node->LegalizeDType(); data_ = std::move(node); } diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index fdb064076bca..3cf5ed2ecf7c 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -123,7 +123,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { return std::move(node); } else { auto writer = modified.CopyOnWrite(); - writer->LegalizeDtype(); + writer->LegalizeDType(); return std::move(modified); } } diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 8acfe6f3ede4..b415b735b7da 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1397,7 +1397,7 @@ class VectorTypeRewriter : public StmtExprMutator { return std::move(node); } else { auto writer = modified.CopyOnWrite(); - writer->LegalizeDtype(); + writer->LegalizeDType(); return std::move(modified); } } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index a3437d88e3d3..feb396569ff9 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -385,7 +385,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorindices)) { auto writer = load.CopyOnWrite(); writer->indices = indices; - writer->LegalizeDtype(); + writer->LegalizeDType(); } return std::move(load); From 619beb528cf57640d69798b0cb4f82dcd4c6de93 Mon Sep 17 00:00:00 2001 From: adstraw Date: Thu, 3 Mar 2022 09:53:31 -0800 Subject: [PATCH 169/177] fix format and lint errors --- python/tvm/contrib/hexagon/build.py | 4 +- python/tvm/contrib/hexagon/session.py | 1 + python/tvm/topi/utils.py | 4 +- .../test_hexagon/test_cache_read_write.py | 6 +-- .../contrib/test_hexagon/test_launcher.py | 24 +++++----- tests/python/frontend/tflite/test_forward.py | 1 + tests/python/relay/aot/aot_test_utils.py | 4 +- .../test_topi_group_conv2d_transpose.py | 45 ++++++++++++++----- ..._meta_schedule_postproc_verify_gpu_code.py | 2 + 9 files changed, 59 insertions(+), 32 deletions(-) diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index 52dcd3049566..17a95b0ce333 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -278,7 +278,9 @@ def _copy_to_remote( self, local_path: Union[str, pathlib.Path], remote_path: Union[str, pathlib.Path] ): """Abstract method implementation. See description in HexagonLauncherRPC.""" - subprocess.check_call(self._adb_device_sub_cmd + ["push", str(local_path), str(remote_path)]) + subprocess.check_call( + self._adb_device_sub_cmd + ["push", str(local_path), str(remote_path)] + ) def _create_remote_directory(self, remote_path: Union[str, pathlib.Path]): """Abstract method implementation. See description in HexagonLauncherRPC.""" diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index 724c3e227faa..2d3f075daa05 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -22,6 +22,7 @@ from typing import Union from tvm import rpc as _rpc + class Session: """Hexagon Device Session diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 9934a045875a..5f43b068e25f 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -310,10 +310,10 @@ def unravel_index(idx, shape): idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod indices = [] - for i,dim in enumerate(reversed(shape)): + for i, dim in enumerate(reversed(shape)): if dim == 0: indices.append(0) - elif i == len(shape)-1: + elif i == len(shape) - 1: # Assuming the index is in-bounds, the last coordinate is # already less than dim, and doesn't need the be remainder # mod dim. diff --git a/tests/python/contrib/test_hexagon/test_cache_read_write.py b/tests/python/contrib/test_hexagon/test_cache_read_write.py index 320e0ba8e279..8216d07adece 100644 --- a/tests/python/contrib/test_hexagon/test_cache_read_write.py +++ b/tests/python/contrib/test_hexagon/test_cache_read_write.py @@ -112,9 +112,9 @@ def test_cache_read_write(android_serial_number, tvm_tracker_host, tvm_tracker_p pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") rpc_info = { - "rpc_tracker_host" : tvm_tracker_host, - "rpc_tracker_port" : tvm_tracker_port, - "rpc_server_port" : 7070, + "rpc_tracker_host": tvm_tracker_host, + "rpc_tracker_port": tvm_tracker_port, + "rpc_server_port": 7070, } launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) launcher.upload(dso_binary_path, dso_binary) diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index 34a4f4d69e47..7855d32d63b7 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -53,9 +53,9 @@ def test_add(android_serial_number, tvm_tracker_host, tvm_tracker_port): pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") rpc_info = { - "rpc_tracker_host" : tvm_tracker_host, - "rpc_tracker_port" : tvm_tracker_port, - "rpc_server_port" : 7070, + "rpc_tracker_host": tvm_tracker_host, + "rpc_tracker_port": tvm_tracker_port, + "rpc_server_port": 7070, } launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) launcher.upload(dso_binary_path, dso_binary) @@ -97,9 +97,9 @@ def test_add_vtcm(android_serial_number, tvm_tracker_host, tvm_tracker_port): pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") rpc_info = { - "rpc_tracker_host" : tvm_tracker_host, - "rpc_tracker_port" : tvm_tracker_port, - "rpc_server_port" : 7070, + "rpc_tracker_host": tvm_tracker_host, + "rpc_tracker_port": tvm_tracker_port, + "rpc_server_port": 7070, } launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) launcher.upload(dso_binary_path, dso_binary) @@ -149,9 +149,9 @@ def test_matmul(self, android_serial_number, tvm_tracker_host, tvm_tracker_port, pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") rpc_info = { - "rpc_tracker_host" : tvm_tracker_host, - "rpc_tracker_port" : tvm_tracker_port, - "rpc_server_port" : 7070, + "rpc_tracker_host": tvm_tracker_host, + "rpc_tracker_port": tvm_tracker_port, + "rpc_server_port": 7070, } launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) launcher.upload(dso_binary_path, dso_binary) @@ -220,9 +220,9 @@ def test_graph_executor(android_serial_number, tvm_tracker_host, tvm_tracker_por pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") rpc_info = { - "rpc_tracker_host" : tvm_tracker_host, - "rpc_tracker_port" : tvm_tracker_port, - "rpc_server_port" : 7070, + "rpc_tracker_host": tvm_tracker_host, + "rpc_tracker_port": tvm_tracker_port, + "rpc_server_port": 7070, } launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) launcher.upload(dso_binary_path, dso_binary) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 05b035e56239..599669e86d84 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -3795,6 +3795,7 @@ def test_forward_prelu(): np.full((3), 0.2, dtype="float32"), ) + ####################################################################### # DepthToSpace # ------------ diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index deda4bc7cdfd..5ca25fd1b204 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -276,9 +276,7 @@ def subprocess_check_log_output(cmd, cwd, logfile): proc.wait() if proc.returncode != 0: - raise RuntimeError( - f"Subprocess failed: {cmd}\nstdout:\n{stdout}" - ) + raise RuntimeError(f"Subprocess failed: {cmd}\nstdout:\n{stdout}") # TODO: Move to linker script with list of symbols rather than coding into source diff --git a/tests/python/topi/python/test_topi_group_conv2d_transpose.py b/tests/python/topi/python/test_topi_group_conv2d_transpose.py index 3ba4550aa1e5..f55b906990fb 100644 --- a/tests/python/topi/python/test_topi_group_conv2d_transpose.py +++ b/tests/python/topi/python/test_topi_group_conv2d_transpose.py @@ -25,7 +25,10 @@ from tvm.topi.utils import get_const_tuple _group_conv2d_nchw_implement = { - "generic": (topi.nn.group_conv2d_transpose_nchw, topi.generic.schedule_group_conv2d_transpose_nchw), + "generic": ( + topi.nn.group_conv2d_transpose_nchw, + topi.generic.schedule_group_conv2d_transpose_nchw, + ), "cuda": (topi.cuda.conv2d_transpose_nchw, topi.cuda.schedule_conv2d_transpose_nchw), } @@ -124,17 +127,37 @@ def test_group_conv2d_transpose_nchw(): verify_group_conv2d_transpose_nchw(1, 4, (32, 32), 4, (5, 5), (1, 1), (0, 0, 0, 0), (0, 0), 2) verify_group_conv2d_transpose_nchw(1, 9, (32, 32), 9, (5, 5), (1, 1), (0, 0, 0, 0), (0, 0), 3) verify_group_conv2d_transpose_nchw(1, 4, (32, 32), 16, (5, 5), (2, 2), (1, 1, 1, 1), (0, 0), 4) - verify_group_conv2d_transpose_nchw(1, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0), (0, 0), 2) - verify_group_conv2d_transpose_nchw(1, 512, (8, 1), 256, (31, 1), (2, 1), (14, 0, 15, 0), (0, 0), 16) - verify_group_conv2d_transpose_nchw(1, 512, (8, 1), 256, (31, 1), (2, 1), (14, 0, 15, 0), (1, 0), 16) - verify_group_conv2d_transpose_nchw(1, 64, (64, 64), 64, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 64) - verify_group_conv2d_transpose_nchw(1, 128, (32, 32), 128, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 128) - verify_group_conv2d_transpose_nchw(1, 256, (16, 16), 256, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 256) + verify_group_conv2d_transpose_nchw( + 1, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0), (0, 0), 2 + ) + verify_group_conv2d_transpose_nchw( + 1, 512, (8, 1), 256, (31, 1), (2, 1), (14, 0, 15, 0), (0, 0), 16 + ) + verify_group_conv2d_transpose_nchw( + 1, 512, (8, 1), 256, (31, 1), (2, 1), (14, 0, 15, 0), (1, 0), 16 + ) + verify_group_conv2d_transpose_nchw( + 1, 64, (64, 64), 64, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 64 + ) + verify_group_conv2d_transpose_nchw( + 1, 128, (32, 32), 128, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 128 + ) + verify_group_conv2d_transpose_nchw( + 1, 256, (16, 16), 256, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 256 + ) verify_group_conv2d_transpose_nchw(1, 1, (224, 224), 1, (1, 1), (1, 1), (0, 0, 0, 0), (0, 0), 1) - verify_group_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0), 1) - verify_group_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (3, 3), (0, 0, 0, 0), (0, 0), 1) - verify_group_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0), 1) - verify_group_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1), (0, 0), 1) + verify_group_conv2d_transpose_nchw( + 1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0), 1 + ) + verify_group_conv2d_transpose_nchw( + 1, 3, (224, 224), 32, (3, 3), (3, 3), (0, 0, 0, 0), (0, 0), 1 + ) + verify_group_conv2d_transpose_nchw( + 1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0), 1 + ) + verify_group_conv2d_transpose_nchw( + 1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1), (0, 0), 1 + ) if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index bb734bfc9299..333cd949f8ea 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -430,6 +430,7 @@ def test_postproc_verify_gpu_3(): # threadIdx.x extent). assert not ctx.postprocs[0].apply(sch) + def test_postproc_verify_gpu_4(): mod = GmmCuda0 ctx = _create_context(mod, target=_target()) @@ -450,5 +451,6 @@ def test_postproc_verify_gpu_6(): sch = tir.Schedule(mod, debug_mask="all") assert not ctx.postprocs[0].apply(sch) + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From e930e51cf2588eb1fd70ec1abde9f0d2dc67b89e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 3 Mar 2022 15:04:54 -0600 Subject: [PATCH 170/177] Disable vectorization of AllocateConst buffer in StorageRewrite. --- src/relay/transforms/fold_constant.cc | 10 +--------- src/tir/transforms/storage_rewrite.cc | 24 ++++++++++++++++++------ tests/python/relay/aot/test_crt_aot.py | 19 +++++++++++-------- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index cebeeb3d73fe..a078cabda3f6 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -252,15 +252,7 @@ class ConstantFolder : public MixedModeMutator { // Use a fresh build context in case we are already in a build context. // needed for both execution and creation(due to JIT) - // With fresh_build_ctx(transform::PassContext::Create()); - - // Disabling vectorization in Eval, as temporary workaround for - // testing PR-9727. Constant folding current produces incorrect - // results, looks to be an issue with vectorized access of - // AllocateConst buffers. - auto fresh_context = transform::PassContext::Create(); - fresh_context->config.Set("tir.disable_vectorize", Bool(true)); - With with_context(fresh_context); + With fresh_build_ctx(transform::PassContext::Create()); Map dict = (module_->attrs.defined()) ? module_->attrs->dict : Map(); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index b415b735b7da..8056916a9334 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -980,7 +980,8 @@ struct BufferVarInfo { kPrimFuncParam = (1 << 0), kPrimFuncBufferMap = (1 << 1), kAllocateNode = (1 << 2), - kLetNode = (1 << 3), + kAllocateConstNode = (1 << 3), + kLetNode = (1 << 4), }; // The tir::Var that represents this buffer. @@ -1122,7 +1123,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { void VisitStmt_(const AllocateConstNode* op) final { const Array& extents = op->extents; PrimExpr extent = extents.size() ? extents[extents.size() - 1] : NullValue(); - OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode); + OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateConstNode); StmtExprVisitor::VisitStmt_(op); } @@ -1312,7 +1313,7 @@ class VectorTypeRewriter : public StmtExprMutator { VectorTypeRewriter(const std::unordered_map& info_map, bool rewrite_params = true, bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, bool rewrite_indices = true, - bool rewrite_let_node = true) + bool rewrite_let_node = true, bool rewrite_allocate_const_node = true) : rewrite_indices_(rewrite_indices) { int rewrite_mask = 0; if (rewrite_params) { @@ -1327,6 +1328,9 @@ class VectorTypeRewriter : public StmtExprMutator { if (rewrite_let_node) { rewrite_mask |= BufferVarInfo::kLetNode; } + if (rewrite_allocate_const_node) { + rewrite_mask |= BufferVarInfo::kAllocateConstNode; + } // Rewrite any buffer variables whose preferred type isn't their current type. for (const auto& pair : info_map) { @@ -1576,12 +1580,14 @@ class VectorTypeRewriter : public StmtExprMutator { PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false, bool rewrite_params = true, bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, bool rewrite_indices = true, - bool rewrite_let_node = true) { + bool rewrite_let_node = true, + bool rewrite_allocate_const_node = true) { VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers); checker(f->body); VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map, - rewrite_allocate_node, rewrite_indices, rewrite_let_node); + rewrite_allocate_node, rewrite_indices, rewrite_let_node, + rewrite_allocate_const_node); PrimFuncNode* n = f.CopyOnWrite(); n->body = rewriter(std::move(n->body)); rewriter.Finalize(&f); @@ -1595,7 +1601,13 @@ Pass StorageRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true); - return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true); + // Parameters may not be rewritten, but internal allocations may. + // Vectorization of AllocateConst is currently disabled, as it has + // indexing issues for types that include padding (e.g. int8x3 + // padded out to 32 bits) would require either rewriting + // AllocateConst::data, or would require the code generators to + // handle vectorized constants. + return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false); }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 0147b8cf755a..6ad6f28cfe02 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -758,24 +758,27 @@ class Model(tf.Module): def tf_function(self, x): # Use tf.nn API to create the model tf_strides = [1, strides[0], strides[1], 1] + filter_shape = [kernel_shape[0], kernel_shape[1], 3, 3] + filter1 = tf.constant( + np.arange(np.prod(filter_shape)).reshape(filter_shape), + dtype=tf.float32, + ) op = tf.nn.conv2d( x, - filters=tf.constant( - np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]), - dtype=tf.float32, - ), + filters=filter1, strides=tf_strides, padding=padding, dilations=dilation, ) op = tf.nn.relu(op) # Second convolution + filter2 = tf.constant( + 1000 + np.arange(np.prod(filter_shape)).reshape(filter_shape), + dtype=tf.float32, + ) op2 = tf.nn.conv2d( x, - filters=tf.constant( - np.random.uniform(size=(kernel_shape[0], kernel_shape[1], 3, 3)), - dtype=tf.float32, - ), + filters=filter2, strides=strides, padding=padding, data_format="NHWC", From 4971a0246a5d41d6f32fc2e9c06b699701981959 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 3 Mar 2022 16:49:42 -0600 Subject: [PATCH 171/177] Pass buffer_map through to the PrimFunc in cmsisnn --- src/relay/backend/contrib/cmsisnn/relay_to_tir.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index f06ed254be99..46eacec13b99 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -107,7 +107,7 @@ class RelayToTIRVisitor : public MixedModeMutator { {context_buffer_size}, tir::const_true(), body); } - tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map(), + tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map, Map(), DictAttrs(dict_attrs)); ir_module_->Add(global_var, replacement_func); } From e6e149bb68c16c792c114581d2c18804c93c7f10 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 4 Mar 2022 15:59:51 +0900 Subject: [PATCH 172/177] try disabling problematic winograd test case --- tests/python/topi/python/test_topi_conv2d_winograd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/topi/python/test_topi_conv2d_winograd.py b/tests/python/topi/python/test_topi_conv2d_winograd.py index 82368f118f32..b843fc7b4e39 100644 --- a/tests/python/topi/python/test_topi_conv2d_winograd.py +++ b/tests/python/topi/python/test_topi_conv2d_winograd.py @@ -233,7 +233,7 @@ def test_conv2d_nhwc(): verify_conv2d_nhwc(1, 512, 7, 512, 3, 1, 1) # more shapes - verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, 1) + # verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, 1) verify_conv2d_nhwc(1, 1, 1, 1, 3, 1, 1) verify_conv2d_nhwc(3, 3, 3, 3, 3, 1, 1) verify_conv2d_nhwc(2, 13, 71, 59, 3, 1, 1) From 054af2eda7199c35e371a34ce09e0504fc21bfd1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 4 Mar 2022 17:56:36 +0900 Subject: [PATCH 173/177] try different way of buffer mapping in storage_rewrite --- src/tir/transforms/storage_rewrite.cc | 10 +++++----- tests/python/topi/python/test_topi_conv2d_winograd.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 8056916a9334..6e8e824c5fa2 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -430,11 +430,11 @@ class StoragePlanRewriter : public StmtExprMutator { return it->second; } - auto writer = buf.CopyOnWrite(); - writer->data = new_backing_array; - - buffer_remap_[key] = buf; - return buf; + Buffer remapped = Buffer(new_backing_array, buf->dtype, buf->shape, buf->strides, + buf->elem_offset, new_backing_array->name_hint, buf->data_alignment, + buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span); + buffer_remap_[key] = remapped; + return remapped; } Stmt VisitStmt_(const BufferStoreNode* op) final { diff --git a/tests/python/topi/python/test_topi_conv2d_winograd.py b/tests/python/topi/python/test_topi_conv2d_winograd.py index b843fc7b4e39..82368f118f32 100644 --- a/tests/python/topi/python/test_topi_conv2d_winograd.py +++ b/tests/python/topi/python/test_topi_conv2d_winograd.py @@ -233,7 +233,7 @@ def test_conv2d_nhwc(): verify_conv2d_nhwc(1, 512, 7, 512, 3, 1, 1) # more shapes - # verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, 1) + verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, 1) verify_conv2d_nhwc(1, 1, 1, 1, 3, 1, 1) verify_conv2d_nhwc(3, 3, 3, 3, 3, 1, 1) verify_conv2d_nhwc(2, 13, 71, 59, 3, 1, 1) From c1d5fc2a142fc77120568f5206b742b4a1d85719 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Mar 2022 11:02:21 -0600 Subject: [PATCH 174/177] Removed unnecessary ramp node in ir_builder. --- python/tvm/tir/ir_builder.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 928e5007df61..99235df041ce 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -90,12 +90,6 @@ def _normalize_index(self, index): except TypeError: index = [index] - t = DataType(self._content_type) - if t.lanes > 1: - base = index[-1] * t.lanes - stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype) - index[-1] = _expr.Ramp(base, stride, t.lanes) - index = [x.var if isinstance(x, _expr.IterVar) else x for x in index] # Workaround to support previous behavior of ir_builder From d1a5123ead5d239c7e4b99900d05af726cae5421 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 4 Mar 2022 11:39:59 -0600 Subject: [PATCH 175/177] Fix lint error. --- python/tvm/tir/ir_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 99235df041ce..334902b53229 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -17,7 +17,7 @@ """Developer API of IR node builder make function.""" import tvm from tvm._ffi.base import string_types -from tvm.runtime import ObjectGeneric, DataType, convert, const +from tvm.runtime import ObjectGeneric, convert, const from tvm.ir import container as _container from . import stmt as _stmt From 5dca3ff6b399e2fbaf6229d413fb1087930be765 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 6 Mar 2022 14:13:38 -0600 Subject: [PATCH 176/177] Updated LLVM codegen for buffer indexing. TVM data arrays are always densely packed. If the LLVM type corresponding to a vectorized TVM datatype contains padding for alignment, the array location should be computed based on the primitive element type. --- src/target/llvm/codegen_cpu.cc | 6 +- src/target/llvm/codegen_hexagon.cc | 6 +- src/target/llvm/codegen_llvm.cc | 94 +++++++++++++++++++++--------- src/target/llvm/codegen_llvm.h | 5 +- 4 files changed, 77 insertions(+), 34 deletions(-) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 6d9d98072ee6..ded346eaaf36 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -810,11 +810,13 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& llvm::Value* arg_value = builder_->CreateInBoundsGEP( t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); - TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + TypedPointer arg_tcode = + CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32)); llvm::Value* ret_value = builder_->CreateInBoundsGEP( t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + TypedPointer ret_tcode = + CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32)); #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 496c73afa4f5..32587030ba17 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -319,11 +319,13 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const ArrayCreateInBoundsGEP( t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); - TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + TypedPointer arg_tcode = + CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32)); llvm::Value* ret_value = builder_->CreateInBoundsGEP( t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + TypedPointer ret_tcode = + CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32)); #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 2cb7d612b5d1..c78a89f23d27 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -788,18 +788,35 @@ llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { return ptr; } -CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, - llvm::Value* index) { - llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); - ICHECK(btype != nullptr); - llvm::Type* llvm_type = DTypeToLLVMType(t); - ICHECK(llvm_type) << "Could not make LLVM type to represent TVM type " << t; - llvm::PointerType* ttype = llvm_type->getPointerTo(btype->getAddressSpace()); - if (btype != ttype) { - buffer = builder_->CreatePointerCast(buffer, ttype); +CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr, + DataType buffer_element_dtype, + llvm::Value* index, DataType value_dtype) { + llvm::PointerType* buffer_ptr_type = llvm::dyn_cast(buffer_ptr->getType()); + ICHECK(buffer_ptr_type != nullptr); + auto address_space = buffer_ptr_type->getAddressSpace(); + + llvm::Type* element_type = DTypeToLLVMType(buffer_element_dtype); + llvm::PointerType* element_ptr_type = + DTypeToLLVMType(buffer_element_dtype)->getPointerTo(address_space); + llvm::Type* value_type = DTypeToLLVMType(value_dtype); + llvm::PointerType* value_ptr_type = value_type->getPointerTo(address_space); + + ICHECK(index->getType()->isIntegerTy()) << "Expected buffer index to be an integer"; + + if (buffer_ptr_type != element_ptr_type) { + buffer_ptr = builder_->CreatePointerCast(buffer_ptr, element_ptr_type); } - llvm::Value* ptr = builder_->CreateInBoundsGEP(llvm_type, buffer, index); - return TypedPointer(llvm_type, ptr); + ICHECK(!HasAlignmentPadding(buffer_element_dtype)) + << "DType " << buffer_element_dtype + << " has padding for alignment. TVM data arrays are expected to be densely packed, with no " + "padding for alignment."; + llvm::Value* value_ptr = builder_->CreateInBoundsGEP(element_type, buffer_ptr, index); + + if (element_ptr_type != value_ptr_type) { + value_ptr = builder_->CreatePointerCast(value_ptr, value_ptr_type); + } + + return TypedPointer(value_type, value_ptr); } llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const { @@ -991,8 +1008,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (const RampNode* r = index.as()) { index = r->base; } - TypedPointer buffer_ptr = - CreateBufferPtr(load->buffer->dtype, MakeValue(load->buffer->data), MakeValue(index)); + TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype, + MakeValue(index), load->dtype); unsigned addrspace = llvm::dyn_cast(buffer_ptr.addr->getType())->getAddressSpace(); return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace)); @@ -1248,6 +1265,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { return NULL; } +bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { + const llvm::DataLayout& data_layout = module_->getDataLayout(); + int bytes = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype)); + int bytes_scalar = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype.element_of())); + return bytes != bytes_scalar*dtype.lanes(); +} + llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; @@ -1257,13 +1281,20 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { PrimExpr buffer_index = op->indices[0]; bool is_volatile = volatile_buf_.count(buffer_var.get()); - llvm::Value* buffer = MakeValue(buffer_var); - llvm::Value* index = MakeValue(buffer_index); if (t.lanes() == buffer_element_dtype.lanes()) { int alignment, native_bits; GetAlignment(t, buffer_var.get(), buffer_index, &alignment, &native_bits); - TypedPointer buffer_ptr = CreateBufferPtr(buffer_element_dtype, buffer, index); + + TypedPointer buffer_ptr; + if (HasAlignmentPadding(buffer_element_dtype)) { + buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(), + MakeValue(buffer_element_dtype.lanes() * buffer_index), t); + } else { + buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, + MakeValue(buffer_index), t); + } + #if TVM_LLVM_VERSION >= 110 llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), is_volatile); @@ -1283,13 +1314,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { GetAlignment(t, buffer_var.get(), ramp->base, &alignment, &native_bits); ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), t.lanes()); // The index argument is element-based, to create buffer pointer for t's element type. - TypedPointer buffer_ptr = - CreateBufferPtr(buffer_element_dtype, buffer, MakeValue(ramp->base)); - unsigned addrspace = - llvm::dyn_cast(buffer->getType())->getAddressSpace(); - buffer_ptr.type = DTypeToLLVMType(t); - buffer_ptr.addr = - builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); + TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype, + MakeValue(ramp->base), t); #if TVM_LLVM_VERSION >= 110 llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), is_volatile); @@ -1308,7 +1334,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { int basic_align = t.bits() / 8; llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t)); auto f = [&](int i, llvm::Value* index) { - TypedPointer buffer_ptr = CreateBufferPtr(buffer_element_dtype, buffer, index); + TypedPointer buffer_ptr = + CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype, index, t.element_of()); #if TVM_LLVM_VERSION >= 110 llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(basic_align), is_volatile); @@ -1399,13 +1426,21 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { bool is_volatile = volatile_buf_.count(buffer_var.get()); llvm::Value* buffer = MakeValue(buffer_var); - llvm::Value* index = MakeValue(buffer_index); llvm::Value* value = MakeValue(op->value); if (value_dtype.lanes() == buffer_element_dtype.lanes()) { int alignment, native_bits; GetAlignment(value_dtype, buffer_var.get(), buffer_index, &alignment, &native_bits); - TypedPointer buffer_ptr = CreateBufferPtr(buffer_element_dtype, buffer, index); + + TypedPointer buffer_ptr; + if (HasAlignmentPadding(buffer_element_dtype)) { + buffer_ptr = + CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(), + MakeValue(buffer_element_dtype.lanes() * buffer_index), value_dtype); + } else { + buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, + MakeValue(buffer_index), value_dtype); + } #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, llvm::Align(alignment), is_volatile); @@ -1423,8 +1458,8 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { GetAlignment(value_dtype, buffer_var.get(), ramp->base, &alignment, &native_bits); ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), value_dtype.lanes()); // The index argument is element-based, to create buffer pointer for t's element type. - TypedPointer buffer_ptr = - CreateBufferPtr(buffer_element_dtype, buffer, MakeValue(ramp->base)); + TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, + MakeValue(ramp->base), value_dtype); unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); buffer_ptr.type = DTypeToLLVMType(value_dtype); @@ -1446,7 +1481,8 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { // scalarized store. int basic_align = value_dtype.bits() / 8; auto f = [&](int i, llvm::Value* index) { - TypedPointer buffer_ptr = CreateBufferPtr(buffer_element_dtype, buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, + index, value_dtype.element_of()); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), buffer_ptr.addr, diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 2b40f761a3b5..e8cbe7ae445f 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -321,6 +321,8 @@ class CodeGenLLVM : public ExprFunctor, // Get alignment given index. void GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, int* p_native_bits); + // Returns whether the LLVM type has padding for alignment + bool HasAlignmentPadding(DataType dtype); // Get constant string llvm::Constant* GetConstString(const std::string& str); // do a scalarize call with f @@ -340,7 +342,8 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); - TypedPointer CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index); + TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, + llvm::Value* index, DataType value_dtype); // Vector concatenation. llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent); llvm::Value* CreateVecFlip(llvm::Value* vec); From 084c21c57ccbdaf5000db8aa518fa1bebe19f4d1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 6 Mar 2022 14:54:09 -0600 Subject: [PATCH 177/177] Resolve lint error. --- src/target/llvm/codegen_llvm.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index c78a89f23d27..cc2e495f6e37 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1269,7 +1269,7 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { const llvm::DataLayout& data_layout = module_->getDataLayout(); int bytes = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype)); int bytes_scalar = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype.element_of())); - return bytes != bytes_scalar*dtype.lanes(); + return bytes != bytes_scalar * dtype.lanes(); } llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {