From 2dbb9020220430f1616e9f49be00b6a81655bb0a Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 4 Oct 2021 14:27:46 -0500
Subject: [PATCH 1/9] [TIR] Changed AllocateNode::extents to a single AllocateNode::extent

This commit implements the primary goal of the PR: it replaces the
AllocateNode::extents variable, which represented the N-d shape of the
tensor, with a 1-d AllocateNode::extent variable representing the total
number of elements to be allocated. This is part of a larger push to
support more flexible memory layouts across multiple device types, and
to introduce a split between the physical layout in memory and the
logical layout of a tensor or a compute definition.
---
 include/tvm/tir/stmt.h         | 20 ++++++++---------
 python/tvm/tir/stmt.py         |  8 +++----
 python/tvm/topi/cuda/sparse.py |  1 -
 src/tir/ir/stmt.cc             | 39 ++++++++++++----------------------
 4 files changed, 28 insertions(+), 40 deletions(-)

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 5cd860b8e929..480ceebbf315 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -515,8 +515,8 @@ class AllocateNode : public StmtNode {
   Var buffer_var;
   /*! \brief The type of the buffer. */
   DataType dtype;
-  /*! \brief The extents of the buffer. */
-  Array<PrimExpr> extents;
+  /*! \brief The extent of the buffer. */
+  PrimExpr extent;
   /*! \brief Only allocate buffer when condition is satisfied. */
   PrimExpr condition;
   /*! \brief The body to be executed. */
@@ -532,7 +532,7 @@
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("buffer_var", &buffer_var);
     v->Visit("dtype", &dtype);
-    v->Visit("extents", &extents);
+    v->Visit("extent", &extent);
     v->Visit("condition", &condition);
     v->Visit("body", &body);
     v->Visit("annotations", &annotations);
@@ -541,14 +541,14 @@
   bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
     return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
-           equal(extents, other->extents) && equal(condition, other->condition) &&
+           equal(extent, other->extent) && equal(condition, other->condition) &&
            equal(body, other->body) && equal(annotations, other->annotations);
   }

   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce.DefHash(buffer_var);
     hash_reduce(dtype);
-    hash_reduce(extents);
+    hash_reduce(extent);
     hash_reduce(condition);
     hash_reduce(body);
     hash_reduce(annotations);
@@ -559,14 +559,14 @@
    * Otherwise return 0.
    * \return The result.
    */
-  int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
+  int32_t constant_allocation_size() const { return constant_allocation_size(extent); }

   /*!
    * \brief If the buffer size is constant, return the size.
    * Otherwise return 0.
-   * \param extents The extents of the buffer.
+   * \param extent The extent of the buffer.
    * \return The result.
    */
-  TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
+  TVM_DLL static int32_t constant_allocation_size(const PrimExpr& extent);

   static constexpr const char* _type_key = "tir.Allocate";
   TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
@@ -578,8 +578,8 @@
  */
 class Allocate : public Stmt {
  public:
-  TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
-                   Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
+  TVM_DLL Allocate(Var buffer_var, DataType dtype, PrimExpr extent, PrimExpr condition, Stmt body,
+                   Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
                    Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index de200d5eabdd..ea6b3b6a7945 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -309,8 +309,8 @@ class Allocate(Stmt):
     dtype : str
         The data type of the buffer.

-    extents : list of Expr
-        The extents of the allocate
+    extent : Expr
+        The number of elements to allocate

     condition : PrimExpr
         The condition.
@@ -325,14 +325,14 @@
         The location of this itervar in the source code.
     """

-    def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None, span=None):
+    def __init__(self, buffer_var, dtype, extent, condition, body, annotations=None, span=None):
         if annotations is None:
             annotations = dict()
         self.__init_handle_by_constructor__(
             _ffi_api.Allocate,  # type: ignore
             buffer_var,
             dtype,
-            extents,
+            extent,
             condition,
             body,
             annotations,
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index 32f20a15016e..70baef923bb3 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -149,7 +149,6 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
     warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
     m = data.shape[1]
     nb = w_indptr.shape[0] - 1
-    nnzb = w_data.shape[0]
     # treat csr like block size 1 bsr
     if len(w_data.shape) == 1:
         bs_n = 1
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 0d42c20c2822..bf37d34ddaef 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -332,18 +332,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     });

 // Allocate
-Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
-                   Stmt body, Map<String, ObjectRef> annotations, Span span) {
+Allocate::Allocate(Var buffer_var, DataType dtype, PrimExpr extent, PrimExpr condition, Stmt body,
+                   Map<String, ObjectRef> annotations, Span span) {
   CHECK(IsPointerType(buffer_var->type_annotation, dtype))
       << "The allocated data type (" << dtype
       << ") does not match the type annotation of the buffer " << buffer_var << " ("
       << buffer_var->type_annotation
       << "). The data type should be an element of the pointer type.";
-  for (size_t i = 0; i < extents.size(); ++i) {
-    ICHECK(extents[i].defined());
-    ICHECK(extents[i].dtype().is_scalar());
-  }
+  ICHECK(extent.defined());
+  ICHECK(extent.dtype().is_scalar());
   ICHECK(body.defined());
   ICHECK(condition.defined());
   ICHECK(condition.dtype().is_bool());
@@ -351,7 +349,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, Prim
   ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
   node->buffer_var = std::move(buffer_var);
   node->dtype = dtype;
-  node->extents = std::move(extents);
+  node->extent = std::move(extent);
   node->condition = std::move(condition);
   node->body = std::move(body);
   node->annotations = std::move(annotations);
@@ -359,25 +357,18 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, Prim
   data_ = std::move(node);
 }

-int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
-  int64_t result = 1;
-  for (size_t i = 0; i < extents.size(); ++i) {
-    if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) {
-      result *= int_size->value;
-      if (result > std::numeric_limits<int32_t>::max()) {
-        return 0;
-      }
-    } else {
-      return 0;
-    }
+int32_t AllocateNode::constant_allocation_size(const PrimExpr& extent) {
+  if (const IntImmNode* int_size = extent.as<IntImmNode>()) {
+    // A size that does not fit in an int32 is reported as non-constant.
+    if (int_size->value > std::numeric_limits<int32_t>::max()) {
+      return 0;
+    }
+    return static_cast<int32_t>(int_size->value);
+  } else {
+    return 0;
   }
-  return static_cast<int32_t>(result);
 }

 TVM_REGISTER_GLOBAL("tir.Allocate")
-    .set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
+    .set_body_typed([](Var buffer_var, DataType type, PrimExpr extent, PrimExpr condition,
                        Stmt body, Map<String, ObjectRef> annotations, Span span) {
-      return Allocate(buffer_var, type, extents, condition, body, annotations, span);
+      return Allocate(buffer_var, type, extent, condition, body, annotations, span);
     });

 TVM_REGISTER_NODE_TYPE(AllocateNode);
@@ -389,10 +380,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       ICHECK(ptr_type) << "The provided variable is not of pointer type";
       p->PrintIndent();
       p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
-      for (size_t i = 0; i < op->extents.size(); ++i) {
-        p->stream << " * ";
-        p->Print(op->extents[i]);
-      }
+      p->stream << " * ";
+      p->Print(op->extent);
       p->stream << "], storage_scope = " << ptr_type->storage_scope;
       if (!is_one(op->condition)) {
         p->stream << " if ";

From dca38e2fab5e1f694450753479dd03092d91bb28 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 4 Oct 2021 14:29:54 -0500
Subject: [PATCH 2/9] C++ changes needed to compile with AllocateNode::extent

After updating AllocateNode to hold a 1-d extent, this commit makes the
secondary changes needed elsewhere in the codebase to compile against
the 1-d extent.
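A minimal before/after sketch of the change at a Python call site (the
buffer name and the 4x16 shape are illustrative only, not taken from
this diff):

    import tvm
    from tvm import tir

    dtype = "float32"
    buf = tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType(dtype)))
    cond = tir.IntImm("bool", 1)
    body = tir.Evaluate(0)

    # Before this series: extents carried the N-d logical shape.
    #   alloc = tir.Allocate(buf, dtype, [4, 16], cond, body)

    # After this series: a single extent carries the flat element
    # count, so callers multiply the shape out themselves.
    alloc = tir.Allocate(buf, dtype, 4 * 16, cond, body)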
--- include/tvm/tir/buffer.h | 7 ++++ src/printer/tir_text_printer.cc | 5 +-- src/printer/tvmscript_printer.cc | 4 +- src/te/operation/cross_thread_reduction.cc | 5 +-- src/tir/analysis/calculate_workspace.cc | 10 +---- src/tir/ir/buffer.cc | 15 +++++++ src/tir/ir/stmt_functor.cc | 8 ++-- src/tir/transforms/bf16_legalize.cc | 2 +- src/tir/transforms/bound_checker.cc | 24 +++-------- src/tir/transforms/flatten_buffer.cc | 2 +- src/tir/transforms/inject_double_buffer.cc | 11 ++--- src/tir/transforms/inject_virtual_thread.cc | 42 ++++++------------- src/tir/transforms/ir_utils.cc | 2 +- src/tir/transforms/lift_attr_scope.cc | 2 +- src/tir/transforms/lower_custom_datatypes.cc | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 16 +++---- src/tir/transforms/lower_tvm_builtin.cc | 5 +-- src/tir/transforms/lower_warp_memory.cc | 2 +- ...merge_dynamic_shared_memory_allocations.cc | 7 ++-- src/tir/transforms/remove_no_op.cc | 2 +- src/tir/transforms/split_host_device.cc | 8 +--- src/tir/transforms/storage_flatten.cc | 12 +++--- src/tir/transforms/storage_rewrite.cc | 26 ++++-------- .../update_pointer_storage_scope.cc | 2 +- src/tir/transforms/vectorize_loop.cc | 16 +++---- tests/cpp/ir_functor_test.cc | 26 ++++++------ 26 files changed, 109 insertions(+), 154 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index f04209d0b061..c84eda466570 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -129,6 +129,13 @@ class BufferNode : public Object { */ PrimExpr ElemOffset(Array index) const; + /*! \brief Return number of elements in the buffer + * + * If the size of the buffer isn't constant, or if the size would + * overflow a 32-bit signed integer, return 0. + */ + int32_t NumElements() const; + static constexpr const char* _type_key = "tir.Buffer"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index fa132f079793..4ef4da8c7614 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -449,9 +449,8 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; auto scope = GetPtrStorageScope(op->buffer_var); - doc << "allocate(" << Print(op->buffer_var) << ", "; - doc << PrintDType(op->dtype) << ", "; - doc << Print(op->extents) << "), storage_scope = " << scope; + doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " + << Print(op->extent) << "), storage_scope = " << scope; if (!op->annotations.empty()) { std::vector attr_docs; for (const auto& it : op->annotations) { diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index fa74e56f491c..124fa0400557 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -764,7 +764,7 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; auto storage_scope = GetPtrStorageScope(op->buffer_var); if (current_num_ != num_child_ - 1) { - doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " + doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extent) << ", " << PrintDType(op->dtype) << ", " << Print(storage_scope); if (!is_one(op->condition)) { doc << ", " << Print(op->condition); @@ -777,7 +777,7 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { doc << ") as " << Print(op->buffer_var) << ":"; doc << Doc::Indent(4, 
Doc::NewLine() << PrintBody(op->body)); } else { - doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extents) + doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extent) << ", " << PrintDType(op->dtype) << ", " << Print(storage_scope); if (!is_one(op->condition)) { doc << ", " << Print(op->condition); diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 2ed5fd4029a2..b3199405916f 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -224,10 +224,9 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, assign_body = MergeNest(MakeIfNest(output_preds), assign_body); Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { - body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, 1, const_true(), body); if (!normal_red.empty()) { - body = - Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, 1, const_true(), body); } } body = Substitute(body, value_map); diff --git a/src/tir/analysis/calculate_workspace.cc b/src/tir/analysis/calculate_workspace.cc index 49ddaf613c6d..739fb9b99d52 100644 --- a/src/tir/analysis/calculate_workspace.cc +++ b/src/tir/analysis/calculate_workspace.cc @@ -55,15 +55,7 @@ size_t WorkspaceCalculator::GetByteAlignedSize(size_t non_aligned_size) { size_t WorkspaceCalculator::CalculateExtentsSize(const AllocateNode* op) { size_t element_size_bytes = op->dtype.bytes(); - size_t num_elements = 1; - for (const auto& ext : op->extents) { - if (ext->IsInstance()) { - num_elements *= Downcast(ext)->value; - } else { - // We cant statically calculate workspace for dynamic shapes - num_elements = 0; - } - } + size_t num_elements = op->constant_allocation_size(); return GetByteAlignedSize(num_elements * element_size_bytes); } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 24aacc3c04f7..41fe556e5688 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -291,6 +291,21 @@ inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataTyp } } +int32_t BufferNode::NumElements() const { + int64_t result = 1; + for (const PrimExpr& dim : shape) { + if (const IntImmNode* int_size = dim.as()) { + result *= int_size->value; + if (result > std::numeric_limits::max()) { + return 0; + } + } else { + return 0; + } + } + return static_cast(result); +} + PrimExpr Buffer::vload(Array begin, DataType dtype) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index d60ec72a7589..3b3a9f32ccbe 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -53,7 +53,7 @@ void StmtVisitor::VisitStmt_(const WhileNode* op) { } void StmtVisitor::VisitStmt_(const AllocateNode* op) { - VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); + this->VisitExpr(op->extent); this->VisitStmt(op->body); this->VisitExpr(op->condition); } @@ -304,15 +304,15 @@ Stmt StmtMutator::VisitStmt_(const WhileNode* op) { } Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { - Array extents = Internal::Mutate(this, op->extents); + PrimExpr extent = this->VisitExpr(op->extent); Stmt body = 
this->VisitStmt(op->body); PrimExpr condition = this->VisitExpr(op->condition); - if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) { + if (extent.same_as(op->extent) && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { auto n = CopyOnWrite(op); - n->extents = std::move(extents); + n->extent = std::move(extent); n->body = std::move(body); n->condition = std::move(condition); return Stmt(n); diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 76845cbebd2a..b8f83db977ed 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -220,7 +220,7 @@ class BF16LowerRewriter : public StmtExprMutator { DataType dtype = DataType::UInt(16, op->dtype.lanes()); Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype))); var_remap_[op->buffer_var] = buffer_var; - return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body)); + return VisitStmt(Allocate(buffer_var, dtype, op->extent, op->condition, op->body)); } else { return StmtExprMutator::VisitStmt_(op); } diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 3b6af0644fc9..7b9be15a83ab 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -61,7 +61,7 @@ class BoundChecker : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { // If the shape was updated we should update the hashtable. if (UpdateIsNeeded(op->buffer_var)) { - Update(op->buffer_var, op->extents, op->dtype); + Update(op->buffer_var, op->extent, op->dtype); } return StmtExprMutator::VisitStmt_(op); } @@ -108,28 +108,14 @@ class BoundChecker : public StmtExprMutator { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } - void Update(const Var& buffer_var, const Array& new_shape, const DataType& type) { + void Update(const Var& buffer_var, const PrimExpr new_extent, const DataType& type) { // Sanity check at first. - if (!new_shape.size()) { + if (!new_extent.defined() || !new_extent.dtype().is_scalar() || is_negative_const(new_extent)) { return; } - for (size_t i = 0; i < new_shape.size(); ++i) { - if (!new_shape[0].defined() || !new_shape[i].dtype().is_scalar() || - is_negative_const(new_shape[i])) { - return; - } - } - - // Scalarize the shape. - PrimExpr shape = - Mul(make_const(DataType::UInt(64), type.lanes()), Cast(DataType::UInt(64), new_shape[0])); - for (size_t i = 1; i < new_shape.size(); ++i) { - // Cast to unsigned to avoid integer overlow at frist. - shape = Mul(shape, Mul(make_const(DataType::UInt(64), type.lanes()), - Cast(DataType::UInt(64), new_shape[i]))); - } - mem_to_shape_[buffer_var.get()] = shape; + // Define the extent including lanes. 
+ mem_to_shape_[buffer_var.get()] = Mul(make_const(DataType::UInt(64), type.lanes()), new_extent); } bool IndexIsValid(const PrimExpr& index) const { diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index e0ab95a537e7..a2d31fada422 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -134,7 +134,7 @@ class BufferFlattener : public StmtExprMutator { static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) { String storage_scope = buffer.scope(); PrimExpr area = BufferArea(buffer); - body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body)); + body = Allocate(buffer->data, buffer->dtype, area, const_true(), std::move(body)); return body; } diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 0b45bde28dfe..0137b9133e87 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -107,19 +107,14 @@ class DoubleBufferInjector : public StmtExprMutator { auto it = dbuffer_info_.find(buf); if (it != dbuffer_info_.end()) { it->second.scope = GetPtrStorageScope(op->buffer_var); - it->second.stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), op->extents) * - op->dtype.lanes(); + it->second.stride = op->extent * op->dtype.lanes(); Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - Array new_extents{make_const(op->extents[0].dtype(), 2)}; - for (PrimExpr e : op->extents) { - new_extents.push_back(e); - } + PrimExpr new_extent = mul(make_const(op->extent.dtype(), 2), op->extent); ICHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; alloc_nest.emplace_back( - Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); + Allocate(op->buffer_var, op->dtype, new_extent, op->condition, Evaluate(0))); return op->body; } else { return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 4964bec0334e..1fa2265d75db 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -124,9 +124,7 @@ class VarTouchedAnalysis : public StmtVisitor { } void VisitStmt_(const AllocateNode* op) final { ExprTouched tc(touched_var_, false); - for (size_t i = 0; i < op->extents.size(); ++i) { - tc(op->extents[i]); - } + tc(op->extent); tc.VisitExpr(op->condition); Record(op->buffer_var.get(), tc); this->VisitStmt(op->body); @@ -359,44 +357,30 @@ class VTInjector : public StmtExprMutator { return InjectVTLoop(GetRef(op), true); } - bool changed = false; - Array extents; - for (size_t i = 0; i < op->extents.size(); i++) { - PrimExpr new_ext = this->VisitExpr(op->extents[i]); - if (visit_touched_var_ && !vt_loop_injected_) { - return InjectVTLoop(GetRef(op), true); - } - if (!new_ext.same_as(op->extents[i])) changed = true; - extents.push_back(new_ext); + PrimExpr extent = this->VisitExpr(op->extent); + if (visit_touched_var_ && !vt_loop_injected_) { + return InjectVTLoop(GetRef(op), true); } + visit_touched_var_ = false; Stmt body; // always rewrite if not allow sharing. if (touched_var_.count(op->buffer_var.get()) || !allow_share_) { // place v on highest dimension. 
- PrimExpr stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), op->extents) * - op->dtype.lanes(); - Array other; - other.push_back(make_const(op->extents[0].dtype(), num_threads_)); - for (PrimExpr e : extents) { - other.push_back(e); - } - extents = other; - changed = true; + PrimExpr stride = mul(op->extent, op->dtype.lanes()); + extent = mul(extent, num_threads_); // mark this buffer get touched. alloc_remap_[op->buffer_var.get()] = stride; - // Mutate the body. - body = this->VisitStmt(op->body); - } else { - // Mutate the body. - body = this->VisitStmt(op->body); } - if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) { + + // Mutate the body. + body = this->VisitStmt(op->body); + + if (extent.same_as(op->extent) && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { - return Allocate(op->buffer_var, op->dtype, extents, condition, body); + return Allocate(op->buffer_var, op->dtype, extent, condition, body); } } diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 262906ade2e8..ebd39f6a1018 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -167,7 +167,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return Allocate(new_var, op->dtype, op->extents, op->condition, op->body); + return Allocate(new_var, op->dtype, op->extent, op->condition, op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc index 40d152b3b3b6..6a7fa2319b48 100644 --- a/src/tir/transforms/lift_attr_scope.cc +++ b/src/tir/transforms/lift_attr_scope.cc @@ -55,7 +55,7 @@ class AttrScopeLifter : public StmtMutator { // undefine them attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); - return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body); + return Allocate(op->buffer_var, op->dtype, op->extent, op->condition, body); } else { return stmt; } diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 21f1b18d523b..c971c1a863a1 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -96,7 +96,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(allocate); allocate = stmt.as(); - return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition, + return Allocate(new_buffer_var, new_allocate_type, allocate->extent, allocate->condition, allocate->body); } else { return StmtExprMutator::VisitStmt_(allocate); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 6f7c09cdcf2d..155aabb342e1 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -53,7 +53,7 @@ class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScop // use volatile access to shared buffer. 
body = AttrStmt(remapped, attr::volatile_scope, 1, body); } - return Allocate(remapped, op->dtype, op->extents, op->condition, body); + return Allocate(remapped, op->dtype, op->extent, op->condition, body); } return StmtExprMutator::VisitStmt_(op); } @@ -98,10 +98,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (it != alloc_remap_.end()) { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { - stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); + stmt = Allocate(repl->buffer_var, repl->dtype, repl->extent, repl->condition, op->body); new_storage_scopes_[repl->buffer_var.get()] = "local"; } else { - stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); + stmt = Allocate(repl->buffer_var, repl->dtype, repl->extent, repl->condition, op->body); new_storage_scopes_[repl->buffer_var.get()] = "shared"; } return stmt; @@ -256,7 +256,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Uses a local variable to store the shuffled data. // Later on, this allocation will be properly attached to this statement. Var var("t" + std::to_string(idx), ptr_type); - Stmt s = Allocate(var, types[idx], {PrimExpr(1)}, pred, Evaluate(0)); + Stmt s = Allocate(var, types[idx], PrimExpr(1), pred, Evaluate(0)); local_vars.push_back(s); } @@ -340,8 +340,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Var var = shared_bufs[i]; load_remap_[buffers[i]] = Load(types[i], var, index, pred); store_remap_[buffers[i]] = var; - Array extents{PrimExpr(1)}; - auto node = Allocate(var, types[i], extents, pred, Evaluate(0)); + PrimExpr extent(1); + auto node = Allocate(var, types[i], extent, pred, Evaluate(0)); alloc_remap_[buffers[i]] = node; warp_allocs_.insert(node.get()); } @@ -381,7 +381,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); alloc_remap_[buffers[idx]] = Allocate(shared_bufs[idx], types[idx], - {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0)); + mul(PrimExpr(group_extent), PrimExpr(reduce_extent)), pred, Evaluate(0)); store_remap_[buffers[idx]] = shared_bufs[idx]; } } @@ -391,7 +391,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (auto var : local_vars) { const AllocateNode* repl = var.as(); if (repl) { - body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); + body = Allocate(repl->buffer_var, repl->dtype, repl->extent, repl->condition, body); new_storage_scopes_[repl->buffer_var.get()] = "local"; } } diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 062d67eef165..910d532b1921 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -128,10 +128,7 @@ class BuiltinLower : public StmtExprMutator { } } } - PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes); - for (size_t i = 0; i < op->extents.size(); ++i) { - total_bytes = total_bytes * op->extents[i]; - } + PrimExpr total_bytes = make_const(op->extent.dtype(), nbytes) * op->extent; ICHECK(device_type_.defined()) << "Unknown device type in current IR"; ICHECK(device_id_.defined()) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 30ec148c37dd..92f061514a6c 
100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -227,7 +227,7 @@ class WarpAccessRewriter : protected StmtExprMutator { warp_group_ = (alloc_size + (factor - 1)) / factor; alloc_size = warp_group_ * factor; - return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)}, + return Allocate(op->buffer_var, op->dtype, make_const(DataType::Int(32), alloc_size / width_), op->condition, this->VisitStmt(op->body)); } diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index e8865b260dc1..22cb041425ff 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -68,14 +68,13 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { align = std::max(align, alloc->dtype.bytes()); } for (const auto& alloc : dyn_shmem_allocs_) { - ICHECK_EQ(alloc->extents.size(), 1); buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; - merged_alloc_size_ += alloc->extents[0] * align; + merged_alloc_size_ += alloc->extent * align; } allocated = true; - auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, - const_true(), StmtExprMutator::VisitStmt(op->body)); + auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), merged_alloc_size_, const_true(), + StmtExprMutator::VisitStmt(op->body)); return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span); } return StmtMutator::VisitStmt_(op); diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index aae1749b27db..6ee05b336344 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -81,7 +81,7 @@ class NoOpRemover : public StmtMutator { Stmt VisitStmt_(const AllocateNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); - return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt; + return is_no_op(op->body) ? 
MakeEvaluate(op->extent) : stmt; } Stmt VisitStmt_(const ProducerRealizeNode* op) final { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 795ae9d6a73a..2509499b156c 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -95,12 +95,8 @@ class VarUseDefAnalysis : public StmtExprMutator { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; - ICHECK_GT(op->extents.size(), 0); - dyn_shmem_size_ = op->extents[0]; - for (size_t i = 1; i < op->extents.size(); ++i) { - dyn_shmem_size_ *= op->extents[i]; - } - dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); + ICHECK(op->extent.defined()); + dyn_shmem_size_ = op->extent * op->dtype.bytes(); use_dyn_shmem_ = true; } return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 6a3ce596c2fe..67436ed833d6 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1134,7 +1134,7 @@ class StorageFlattener : public StmtExprMutator { // use small alignment for small arrays auto dtype = op->buffer->dtype; - int32_t const_size = AllocateNode::constant_allocation_size(shape); + int32_t const_size = op->buffer->NumElements(); int align = GetTempAllocaAlignment(dtype, const_size); if (skey.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(skey.to_string()); @@ -1163,14 +1163,12 @@ class StorageFlattener : public StmtExprMutator { if (strides.size() != 0) { int first_dim = 0; ret = Allocate(e.buffer->data, storage_type, - {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, + e.buffer->strides[first_dim] * e.buffer->shape[first_dim], make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } else { - shape = e.buffer->shape; - if (shape.size() == 0) { - shape.push_back(make_const(DataType::Int(32), 1)); - } - ret = Allocate(e.buffer->data, storage_type, shape, + PrimExpr extent = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), e.buffer->shape); + ret = Allocate(e.buffer->data, storage_type, extent, make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 409b7c262954..9ce0d94cb6b2 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -551,10 +551,8 @@ class StoragePlanRewriter : public StmtExprMutator { if (e->allocs.size() == 1) { // simply use the original allocation. 
- PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), e->allocs[0]->extents); - e->new_alloc = - Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0)); + e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extent, + e->allocs[0]->condition, Evaluate(0)); if (IsSpecialTaggedMemory(e->scope)) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -565,8 +563,7 @@ class StoragePlanRewriter : public StmtExprMutator { // Build a merged allocation PrimExpr combo_size; for (const AllocateNode* op : e->allocs) { - PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), op->extents); + PrimExpr sz = op->extent; auto nbits = op->dtype.bits() * op->dtype.lanes(); if (const auto* imm = sz.as()) { if (imm->value > std::numeric_limits::max() / nbits) { @@ -594,8 +591,7 @@ class StoragePlanRewriter : public StmtExprMutator { combo_size = combo_size + make_const(DataType::Int(32), 1); } combo_size = analyzer_.Simplify(combo_size); - e->new_alloc = - Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0)); + e->new_alloc = Allocate(e->alloc_var, alloc_type, combo_size, const_true(), Evaluate(0)); if (IsSpecialTaggedMemory(e->scope)) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -636,8 +632,8 @@ class StoragePlanRewriter : public StmtExprMutator { } uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes(); PrimExpr alloc_size = - make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits); - e->new_alloc = Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate(0)); + make_const(e->allocs[0]->extent.dtype(), (total_bits + type_bits - 1) / type_bits); + e->new_alloc = Allocate(e->alloc_var, e->elem_type, alloc_size, const_true(), Evaluate(0)); if (info.defined()) { ICHECK_LE(total_bits, info->max_num_bits) << "Allocation exceed bound of memory tag " << e->scope.to_string(); @@ -1025,9 +1021,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } void VisitStmt_(const AllocateNode* op) final { - const Array& extents = op->extents; - PrimExpr extent = extents[extents.size() - 1]; - OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode); + OnArrayDeclaration(op->buffer_var, op->dtype, op->extent, BufferVarInfo::kAllocateNode); StmtExprVisitor::VisitStmt_(op); } @@ -1342,10 +1336,8 @@ class VectorTypeRewriter : public StmtExprMutator { int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); - Array extents = op->extents; - extents.Set(extents.size() - 1, - extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); - return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); + PrimExpr extent = op->extent / make_const(op->extent.dtype(), factor); + return Allocate(new_buffer_var, info.new_element_dtype, extent, op->condition, op->body); } /* Update the parameters and all remaining variable references diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 4143577a0b17..b37b67019593 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -66,7 +66,7 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* 
op) { Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); - return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), + return Allocate(remapped, op->dtype, op->extent, StmtExprMutator::VisitExpr(op->condition), StmtExprMutator::VisitStmt(op->body)); } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index cd2d230f5775..8fca62308be2 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -434,21 +434,17 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op)); } - Array extents; - for (size_t i = 0; i < op->extents.size(); i++) { - PrimExpr new_ext = this->VisitExpr(op->extents[i]); - if (new_ext.dtype().is_vector()) { - LOG(WARNING) << "Cannot handle vector extent in alloc "; - return Scalarize(GetRef(op)); - } - extents.push_back(new_ext); + PrimExpr extent = this->VisitExpr(op->extent); + if (extent.dtype().is_vector()) { + LOG(WARNING) << "Cannot handle vector extent in alloc "; + return Scalarize(GetRef(op)); } // place the vector lanes in least significant dimension. - extents.push_back(var_lanes_); + extent *= var_lanes_; // rewrite access to buffer internally. Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); body = this->VisitStmt(body); - return Allocate(op->buffer_var, op->dtype, extents, condition, body); + return Allocate(op->buffer_var, op->dtype, extent, condition, body); } // scalarize the statment diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 97809b0e1398..d065b65c9cd1 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -169,7 +169,7 @@ TEST(IRF, StmtVisitor) { Stmt body = Evaluate(z); DataType dtype = DataType::Float(32); Var buffer("b", PointerType(PrimType(dtype))); - return Allocate(buffer, dtype, {z, z}, const_true(), body); + return Allocate(buffer, dtype, z, const_true(), body); }; v(fmaketest()); ICHECK_EQ(v.count, 3); @@ -215,7 +215,7 @@ TEST(IRF, StmtMutator) { Stmt body = Evaluate(z); DataType dtype = DataType::Float(32); Var buffer("b", PointerType(PrimType(dtype))); - return Allocate(buffer, dtype, {1, z}, const_true(), body); + return Allocate(buffer, dtype, z, const_true(), body); }; auto fmakeif = [&]() { @@ -229,14 +229,14 @@ TEST(IRF, StmtMutator) { auto body = fmakealloc(); Stmt body2 = Evaluate(1); Stmt bref = body.as()->body; - auto* extentptr = body.as()->extents.get(); + auto* extentptr = body.as()->extent.get(); Array arr{std::move(body), body2, body2}; auto* arrptr = arr.get(); arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); ICHECK(arr.get() == arrptr); // inplace update body - ICHECK(arr[0].as()->extents[1].same_as(x)); - ICHECK(arr[0].as()->extents.get() == extentptr); + ICHECK(arr[0].as()->extent.same_as(x)); + ICHECK(arr[0].as()->extent.get() == extentptr); // copy because there is additional refs ICHECK(!arr[0].as()->body.same_as(bref)); ICHECK(arr[0].as()->body.as()->value.same_as(x)); @@ -249,8 +249,8 @@ TEST(IRF, StmtMutator) { auto* arrptr = arr.get(); arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); ICHECK(arr.get() != arrptr); - ICHECK(arr[0].as()->extents[1].same_as(x)); - ICHECK(!arr2[0].as()->extents[1].same_as(x)); + ICHECK(arr[0].as()->extent.same_as(x)); + ICHECK(!arr2[0].as()->extent.same_as(x)); // mutate but no content change. 
arr2 = arr; arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); @@ -276,7 +276,7 @@ TEST(IRF, StmtMutator) { Stmt body = fmakealloc(); Stmt body2 = Evaluate(1); auto* ref2 = body2.get(); - auto* extentptr = body.as()->extents.get(); + auto* extentptr = body.as()->extent.get(); // construct a recursive SeqStmt. body = SeqStmt({body}); body = SeqStmt({body, body2}); @@ -284,7 +284,7 @@ TEST(IRF, StmtMutator) { body = v(std::move(body)); // the seq get flattened ICHECK(body.as()->size() == 3); - ICHECK(body.as()->seq[0].as()->extents.get() == extentptr); + ICHECK(body.as()->seq[0].as()->extent.get() == extentptr); ICHECK(body.as()->seq[1].get() == ref2); } @@ -292,14 +292,14 @@ TEST(IRF, StmtMutator) { // Cannot cow because of bref Stmt body = fmakealloc(); Stmt body2 = Evaluate(1); - auto* extentptr = body.as()->extents.get(); + auto* extentptr = body.as()->extent.get(); // construct a recursive SeqStmt. body = SeqStmt({body}); auto bref = body; body = SeqStmt({body, body2}); body = v(std::move(body)); // the seq get flattened - ICHECK(body.as()->seq[0].as()->extents.get() != extentptr); + ICHECK(body.as()->seq[0].as()->extent.get() != extentptr); } { @@ -317,8 +317,8 @@ TEST(IRF, StmtMutator) { body = v(std::move(block_realize)); // the body should be changed Block new_block = body.as()->block; - ICHECK(new_block->body.as()->extents[1].same_as(x)); - ICHECK(new_block->init.as()->extents[1].same_as(x)); + ICHECK(new_block->body.as()->extent.same_as(x)); + ICHECK(new_block->init.as()->extent.same_as(x)); ICHECK(new_block->reads[0]->region[0]->min.same_as(x)); ICHECK(new_block->writes[0]->region[0]->min.same_as(x)); ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x)); From 2a40204b812fdcd20b176cccfc124bacec133cf5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 16 Sep 2021 09:40:05 -0500 Subject: [PATCH 3/9] Updates to IRBuilder, to add functionality needed. Allow the IRBuilder to output BufferRealize/BufferLoad/BufferStore nodes, as well as Load/Store. This minimizes the changes that will be needed in the unit tests, as they can continue using N-d indices and allocations, which are then flattened during the `StorageFlatten` pass. --- python/tvm/tir/ir_builder.py | 197 ++++++++++++++++++++++++++--------- 1 file changed, 150 insertions(+), 47 deletions(-) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 978c630b17ad..19f7e669491a 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -17,8 +17,9 @@ """Developer API of IR node builder make function.""" from tvm._ffi.base import string_types from tvm.runtime import ObjectGeneric, DataType, convert, const -from tvm.ir import container as _container, PointerType, PrimType +from tvm.ir import container as _container, PointerType, PrimType, Range +from . import buffer as _buffer from . import stmt as _stmt from . import expr as _expr from . import op @@ -38,44 +39,45 @@ def __exit__(self, ptype, value, trace): self._exit_cb() -class BufferVar(ObjectGeneric): - """Buffer variable with content type, makes load store easily. +class BufferVarBuilder(ObjectGeneric): + """Helper to build Load/Store interactions with a buffer. - Do not create it directly, create use IRBuilder. + The BufferVarBuilder gives array access into physical memory. + Indices should be flat values, and are used in Load/Store nodes. - BufferVars support array access either via a linear index, or, if given a - shape, via a multidimensional index. + Do not create a BufferVarBuilder directly. 
Instead, use + `IRBuilder.allocate` or `IRBuilder.pointer`. Examples -------- - In the follow example, x is BufferVar. - :code:`x[0] = ...` directly emit a store to the IRBuilder, + In the follow example, x is BufferVarBuilder. + :code:`x[0] = ...` directly emit a Store to the IRBuilder, :code:`x[10]` translates to Load. .. code-block:: python - # The following code generate IR for x[0] = x[ ib = tvm.tir.ir_builder.create() + + # One-dimensional buffer access x = ib.pointer("float32") x[0] = x[10] + 1 - y = ib.allocate("float32", (32, 32)) - # Array access using a linear index + # Implementing multi-dimensional array access using a linear index + y = ib.allocate("float32", 32*32) y[(2*32) + 31] = 0. - # The same array access using a multidimensional index - y[2, 31] = 0. See Also -------- IRBuilder.pointer IRBuilder.buffer_ptr IRBuilder.allocate + IRBuilder.buffer_realize + """ - def __init__(self, builder, buffer_var, shape, content_type): + def __init__(self, builder, buffer_var, content_type): self._builder = builder self._buffer_var = buffer_var - self._shape = shape self._content_type = content_type def asobject(self): @@ -85,27 +87,20 @@ def asobject(self): def dtype(self): return self._content_type - def _linear_index(self, index): - if not isinstance(index, tuple) or self._shape is None: - return index - assert len(index) == len(self._shape), "Index size (%s) does not match shape size (%s)" % ( - len(index), - len(self._shape), - ) - dim_size = 1 - lidx = 0 - for dim, idx in zip(reversed(self._shape), reversed(index)): - lidx += idx * dim_size - dim_size *= dim - return lidx - - def __getitem__(self, index): + def _normalize_index(self, index): t = DataType(self._content_type) - index = self._linear_index(index) if t.lanes > 1: base = index * t.lanes stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype) index = _expr.Ramp(base, stride, t.lanes) + + if isinstance(index, _expr.IterVar): + index = index.var + + return index + + def __getitem__(self, index): + index = self._normalize_index(index) return _expr.Load(self._content_type, self._buffer_var, index) def __setitem__(self, index, value): @@ -114,13 +109,92 @@ def __setitem__(self, index, value): raise ValueError( "data type does not match content type %s vs %s" % (value.dtype, self._content_type) ) - index = self._linear_index(index) + + index = self._normalize_index(index) + self._builder.emit(_stmt.Store(self._buffer_var, value, index)) + + +class BufferBuilder(ObjectGeneric): + """Helper to build BufferLoad/BufferStore interactions with a buffer. + + The BufferBuilder gives multi-dimensional array access into + logical memory. Indices should have the same number of dimensions + as the underlying buffer. Read/writes to the BufferBuilder + correspond to BufferLoad/BufferStore nodes. For physical memory + access, see BufferVarBuilder. + + Do not create a BufferBuilder directly. Instead, use + `IRBuilder.buffer_realize` or `IRBuilder.buffer_ptr`. + + Examples + -------- + In the follow example, x is BufferVarBuilder. + :code:`x[0] = ...` directly emit a BufferStore to the IRBuilder, + :code:`x[10]` translates to BufferLoad. + + .. 
code-block:: python + + ib = tvm.tir.ir_builder.create() + # One-dimensional buffer access + x = ib.buffer_realize("float32", 16) + x[0] = x[10] + 1.0 + + # Multi-dimensional buffer access + y = ib.buffer_realize("float32", (16, 32)) + # Array access using a multidimensional index + y[2, 31] = 0.0 + + See Also + -------- + IRBuilder.pointer + IRBuilder.buffer_ptr + IRBuilder.allocate + IRBuilder.buffer_realize + + """ + + def __init__(self, builder, buffer, content_type): + self._builder = builder + self._buffer = buffer + self._content_type = content_type + + def asobject(self): + return self._buffer + + @property + def dtype(self): + return self._content_type + + def _normalize_index(self, index): + try: + index = [*index] + except TypeError: + index = [index] + t = DataType(self._content_type) if t.lanes > 1: - base = index * t.lanes + base = index[-1] * t.lanes stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype) - index = _expr.Ramp(base, stride, t.lanes) - self._builder.emit(_stmt.Store(self._buffer_var, value, index)) + index[-1] = _expr.Ramp(base, stride, t.lanes) + + index = [x.var if isinstance(x, _expr.IterVar) else x for x in index] + + return index + + def __getitem__(self, index): + index = self._normalize_index(index) + return _expr.BufferLoad(self._buffer, index) + + def __setitem__(self, index, value): + index = self._normalize_index(index) + + value = convert(value) + if value.dtype != self._content_type: + raise ValueError( + "data type does not match content type %s vs %s" % (value.dtype, self._content_type) + ) + + self._builder.emit(_stmt.BufferStore(self._buffer, value, index)) class IRBuilder(object): @@ -281,7 +355,7 @@ def while_loop(self, condition): .. code-block:: python ib = tvm.tir.ir_builder.create() - iterations = ib.allocate("int32", (1,), name="iterations", scope="local") + iterations = ib.allocate("int32", 1, name="iterations", scope="local") with ib.while_loop(iterations[0] < 10): iterations[0] += 1 """ @@ -394,7 +468,7 @@ def let(self, var_name, value): self.emit(lambda x: _stmt.LetStmt(var, value, x)) return var - def allocate(self, dtype, shape, name="buf", scope=""): + def allocate(self, dtype, extent, name="buf", scope=""): """Create a allocate statement. Parameters @@ -402,8 +476,8 @@ def allocate(self, dtype, shape, name="buf", scope=""): dtype : str The content data type. - shape : tuple of Expr - The shape of array to be allocated. + extent : Expr + The size of array to be allocated. name : str, optional The name of the buffer. @@ -417,10 +491,39 @@ def allocate(self, dtype, shape, name="buf", scope=""): The buffer var representing the buffer. """ buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) + self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, extent, const(1, dtype="uint1"), x)) + return BufferVarBuilder(self, buffer_var, dtype) + + def buffer_realize(self, dtype, shape, name="buf", scope=""): + """Create a BufferRealize statement. + + Parameters + ---------- + dtype : str + The content data type. + + shape : Union[Expr, List[Expr], Tuple[Expr]] + The shape of array to be allocated. + + name : str, optional + The name of the buffer. + + scope : str, optional + The scope of the buffer. + + Returns + ------- + buffer : BufferBuilder + The buffer var representing the buffer. 
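+
+        Examples
+        --------
+        A minimal usage sketch (buffer name and shape are illustrative):
+
+        .. code-block:: python
+
+            ib = tvm.tir.ir_builder.create()
+            buf = ib.buffer_realize("float32", (16, 32), name="buf")
+            # Multi-dimensional indices emit BufferLoad/BufferStore.
+            buf[0, 0] = buf[2, 31] + 1.0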
+ """ + buffer = _buffer.decl_buffer(shape, dtype=dtype, name=name, scope=scope) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] - self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) - return BufferVar(self, buffer_var, shape, dtype) + + bounds = [Range(0, dim_extent) for dim_extent in shape] + + self.emit(lambda x: _stmt.BufferRealize(buffer, bounds, True, x)) + return BufferBuilder(self, buffer, dtype) def pointer(self, content_type, name="ptr", scope=""): """Create pointer variable with content type. @@ -438,14 +541,14 @@ def pointer(self, content_type, name="ptr", scope=""): Returns ------- - ptr : BufferVar + ptr : BufferVarBuilder The buffer var representing the buffer. """ buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope)) - return BufferVar(self, buffer_var, None, content_type) + return BufferVarBuilder(self, buffer_var, content_type) - def buffer_ptr(self, buf, shape=None): - """Create pointer variable corresponds to buffer ptr. + def buffer_ptr(self, buf): + """Create a handle to interact with the buffer specified. Parameters ---------- @@ -457,10 +560,10 @@ def buffer_ptr(self, buf, shape=None): Returns ------- - ptr : BufferVar + ptr : BufferBuilder The buffer var representing the buffer. """ - return BufferVar(self, buf.data, buf.shape if shape is None else shape, buf.dtype) + return BufferBuilder(self, buf, buf.dtype) def likely(self, expr): """Add likely tag for expression. From 957cdd1a34f050e83e5ffb4c9c9c24b06ec685f8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 4 Oct 2021 12:58:04 -0500 Subject: [PATCH 4/9] Updated topi schedules to use singular extent. --- python/tvm/topi/cuda/nms.py | 6 ++---- python/tvm/topi/cuda/rcnn/proposal.py | 8 ++++---- python/tvm/topi/cuda/scan.py | 14 +++++++------- python/tvm/topi/cuda/scatter.py | 6 ++---- python/tvm/topi/cuda/sort.py | 22 +++++++++++----------- python/tvm/topi/cuda/sparse.py | 16 +++++++++------- python/tvm/topi/cuda/sparse_reshape.py | 16 +++++++--------- python/tvm/topi/nn/sparse.py | 8 ++++---- python/tvm/topi/sparse/csrmm.py | 2 +- python/tvm/topi/sparse/csrmv.py | 2 +- python/tvm/topi/sparse/dense.py | 4 ++-- python/tvm/topi/sparse_reshape.py | 16 +++++++--------- python/tvm/topi/vision/nms.py | 4 ++-- python/tvm/topi/vision/nms_util.py | 4 ++-- python/tvm/topi/vision/rcnn/proposal.py | 8 ++++---- 15 files changed, 65 insertions(+), 71 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index e402c5888978..b6384cfb00bb 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -306,9 +306,7 @@ def _nms_loop( ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(tx, "thread_extent", nthread_tx) - num_valid_boxes_local = ib.allocate( - "int32", (1,), name="num_valid_boxes_local", scope="local" - ) + num_valid_boxes_local = ib.allocate("int32", 1, name="num_valid_boxes_local", scope="local") num_valid_boxes_local[0] = 0 def nms_inner_loop(ib, i, j, nkeep): @@ -345,7 +343,7 @@ def nms_inner_loop(ib, i, j, nkeep): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms # No need to do more iteration if we have already reached max_output_size boxes - box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + box_idx = ib.allocate("int32", 1, name="box_idx", scope="local") box_idx[0] = 0 with ib.while_loop( tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) diff --git a/python/tvm/topi/cuda/rcnn/proposal.py 
b/python/tvm/topi/cuda/rcnn/proposal.py index 12f7a23abe35..6bb2e5c2054c 100644 --- a/python/tvm/topi/cuda/rcnn/proposal.py +++ b/python/tvm/topi/cuda/rcnn/proposal.py @@ -176,8 +176,8 @@ def argsort_ir(data_buf, out_index_buf): ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "virtual_thread", nthread_bx) tid = bx * nthread_tx + tx - temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") - temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") + temp_data = ib.allocate("float32", 1, name="temp_data", scope="local") + temp_index = ib.allocate("int32", 1, name="temp_index", scope="local") idxm = tvm.tir.indexmod @@ -299,14 +299,14 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf): tx = te.thread_axis("threadIdx.x") ib = tvm.tir.ir_builder.create() ib.scope_attr(tx, "thread_extent", nthread_tx) - i = ib.allocate("int32", (1,), "i", scope="local") + i = ib.allocate("int32", 1, "i", scope="local") i[0] = 0 p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf) p_remove = ib.buffer_ptr(remove_mask_buf) p_out = ib.buffer_ptr(out_buf) b = tx - nkeep = ib.allocate("int32", (1,), "nkeep", scope="local") + nkeep = ib.allocate("int32", 1, "nkeep", scope="local") nkeep[0] = 0 # number of bbox after nms with ib.for_range(0, num_bbox) as j: diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 0d19a92f2058..6c2151f7b573 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -121,9 +121,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i by = te.thread_axis("blockIdx.y") ib.scope_attr(by, "thread_extent", nthread_by) - start = ib.allocate("int64", (1,), name="start", scope="local") - middle = ib.allocate("int64", (1,), name="middle", scope="local") - end = ib.allocate("int64", (1,), name="end", scope="local") + start = ib.allocate("int64", 1, name="start", scope="local") + middle = ib.allocate("int64", 1, name="middle", scope="local") + end = ib.allocate("int64", 1, name="end", scope="local") start[0] = width * tid with ib.if_scope(start[0] < scan_axis_size): middle[0] = start[0] + tvm.tir.indexdiv(width, 2) @@ -159,10 +159,10 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i by = te.thread_axis("blockIdx.y") ib.scope_attr(by, "thread_extent", nthread_by) - start = ib.allocate("int64", (1,), name="start", scope="local") - middle = ib.allocate("int64", (1,), name="middle", scope="local") - end = ib.allocate("int64", (1,), name="end", scope="local") - tmp = ib.allocate(out_dtype, (1,), name="end", scope="local") + start = ib.allocate("int64", 1, name="start", scope="local") + middle = ib.allocate("int64", 1, name="middle", scope="local") + end = ib.allocate("int64", 1, name="end", scope="local") + tmp = ib.allocate(out_dtype, 1, name="end", scope="local") start[0] = width * tid with ib.if_scope(tvm.tir.all(start[0] < scan_axis_size)): middle[0] = start[0] + tvm.tir.indexdiv(width, 2) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index fa7545cd323a..09bc48404680 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -642,7 +642,7 @@ def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _): ni = indices.shape[0] - atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local") + atomic_add_return = ib.allocate(updates.dtype, 1, name="atomic_add_return", scope="local") with ib.new_scope(): nthread_bx = ceil_div(ni, nthread_tx) 
@@ -772,9 +772,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) - atomic_add_return = ib.allocate( - updates.dtype, (1,), name="atomic_add_return", scope="local" - ) + atomic_add_return = ib.allocate(updates.dtype, 1, name="atomic_add_return", scope="local") fused_indices_dimension = 1 for i in indices_ptr.shape[1:]: diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 25cc7a4e2cfb..9cbe89947fb2 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -134,25 +134,25 @@ def _odd_even_sort( ## Create shared memory as syncable thread scratch space tmp_keys_swap = ib.allocate( keys_swap.dtype, - (block_size,), + block_size, name="temp_keys_swap", scope="shared", ) if values_swap is not None: tmp_values_swap = ib.allocate( values_swap.dtype, - (block_size,), + block_size, name="temp_values_swap", scope="shared", ) ## Create thread local data for swapping - temp_keys = ib.allocate(keys_swap.dtype, (1,), name="temp_keys", scope="local") + temp_keys = ib.allocate(keys_swap.dtype, 1, name="temp_keys", scope="local") if values_swap is not None: - temp_values = ib.allocate(values_swap.dtype, (1,), name="temp_values", scope="local") + temp_values = ib.allocate(values_swap.dtype, 1, name="temp_values", scope="local") - temp_cond1 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond1", scope="local") - temp_cond2 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond2", scope="local") + temp_cond1 = ib.allocate(keys_swap.dtype, 1, name="temp_cond1", scope="local") + temp_cond2 = ib.allocate(keys_swap.dtype, 1, name="temp_cond2", scope="local") # Copy data to scratch space base_idx = by * size * axis_mul_after + bz with ib.for_range(0, 2) as n: @@ -255,9 +255,9 @@ def compare(a, b): upper_lim = ceil_log2(size) def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count): - first = ib.allocate("int64", (1,), name="first", scope="local") - mid = ib.allocate("int64", (1,), name="mid", scope="local") - last = ib.allocate("int64", (1,), name="last", scope="local") + first = ib.allocate("int64", 1, name="first", scope="local") + mid = ib.allocate("int64", 1, name="mid", scope="local") + last = ib.allocate("int64", 1, name="last", scope="local") first[0] = tvm.te.max(0, diag - bCount) last[0] = tvm.te.min(diag, aCount) with ib.while_loop(first[0] < last[0]): @@ -286,8 +286,8 @@ def serial_merge( first, last, ): - i = ib.allocate("int64", (1,), name="i", scope="local") - j = ib.allocate("int64", (1,), name="j", scope="local") + i = ib.allocate("int64", 1, name="i", scope="local") + j = ib.allocate("int64", 1, name="j", scope="local") i[0] = aStart + first j[0] = bStart + diag - last with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count)) as count: diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index 70baef923bb3..866ac58eaf03 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -180,7 +180,7 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - w_data_ptr = ib.buffer_ptr(w_data, shape=(nnzb, bs_n, bs_k)) + w_data_ptr = ib.buffer_ptr(w_data) w_indices_ptr = ib.buffer_ptr(w_indices) w_indptr_ptr = ib.buffer_ptr(w_indptr) @@ -192,18 +192,20 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): rowlength_bo = ceil_div(w_indptr_ptr[n_index + 1] - row_start, rowlength_bi) # thread local storage for bs_m x bs_n block - block = 
ib.allocate(data.dtype, (bs_m, bs_n), name="block", scope="local") - data_cache = ib.allocate(data.dtype, (mi, bs_m, bs_k), name="data_cache", scope="local") + block = ib.buffer_realize(data.dtype, (bs_m, bs_n), name="block", scope="local") + data_cache = ib.buffer_realize( + data.dtype, (mi, bs_m, bs_k), name="data_cache", scope="local" + ) if use_warp_storage: - indices = ib.allocate(w_indices.dtype, (rowlength_bi,), name="indices", scope="warp") - w_data_cache = ib.allocate( + indices = ib.allocate(w_indices.dtype, rowlength_bi, name="indices", scope="warp") + w_data_cache = ib.buffer_realize( w_data.dtype, (rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="warp" ) else: - indices = ib.allocate( + indices = ib.buffer_realize( w_indices.dtype, (ni, rowlength_bi), name="indices", scope="shared" ) - w_data_cache = ib.allocate( + w_data_cache = ib.buffer_realize( w_data.dtype, (ni, rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="shared" ) diff --git a/python/tvm/topi/cuda/sparse_reshape.py b/python/tvm/topi/cuda/sparse_reshape.py index 7a796fa42696..53161a2126cb 100644 --- a/python/tvm/topi/cuda/sparse_reshape.py +++ b/python/tvm/topi/cuda/sparse_reshape.py @@ -88,22 +88,20 @@ def gen_ir( new_shape_size = new_shape_ptr.shape[0] multipliers = ib.allocate( - new_shape_ptr.dtype, (prev_shape_size,), name="multipliers", scope="global" - ) - dividers = ib.allocate( - new_shape_ptr.dtype, (new_shape_size,), name="dividers", scope="global" + new_shape_ptr.dtype, prev_shape_size, name="multipliers", scope="global" ) + dividers = ib.allocate(new_shape_ptr.dtype, new_shape_size, name="dividers", scope="global") flattened_indices = ib.allocate( new_shape_ptr.dtype, - (sparse_indices_ptr.shape[0],), + sparse_indices_ptr.shape[0], name="flattened_indices", scope="global", ) - total_ele = ib.allocate(new_shape_ptr.dtype, (1,), name="total_ele", scope="global") + total_ele = ib.allocate(new_shape_ptr.dtype, 1, name="total_ele", scope="global") division_total_ele = ib.allocate( - new_shape_ptr.dtype, (1,), name="division_total_ele", scope="global" + new_shape_ptr.dtype, 1, name="division_total_ele", scope="global" ) - equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="global") + equal_shape = ib.allocate("bool", 1, name="equal_shape", scope="global") max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): # The computation in this block is very very miniscule since we are just iterating over @@ -183,7 +181,7 @@ def gen_ir( with ib.if_scope(row_number < sparse_indices_ptr.shape[0]): current_element = ib.allocate( - new_shape_ptr.dtype, (1,), name="current_element", scope="local" + new_shape_ptr.dtype, 1, name="current_element", scope="local" ) current_element[0] = flattened_indices[row_number] diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index e577104c3ddc..eb7a09a369ed 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -303,8 +303,8 @@ def _csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr): with irb.for_range(0, nnz, kind="serial", name="nz_idx") as nz_idx: out_indptr_ptr[indices_ptr[nz_idx]] += 1 - cumsum = irb.allocate("int32", (1,), name="cumsum", scope="local") - temp = irb.allocate("int32", (1,), name="temp", scope="local") + cumsum = irb.allocate("int32", 1, name="cumsum", scope="local") + temp = irb.allocate("int32", 1, name="temp", scope="local") cumsum[0] = 0 with irb.for_range(0, n, kind="serial", name="col") as col: temp[0] = 
out_indptr_ptr[col] @@ -325,8 +325,8 @@ def _csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr): out_data_ptr[dest] = data_ptr[real_idx] out_indptr_ptr[col] += 1 - last = irb.allocate("int32", (1,), name="last", scope="local") - temp2 = irb.allocate("int32", (1,), name="temp2", scope="local") + last = irb.allocate("int32", 1, name="last", scope="local") + temp2 = irb.allocate("int32", 1, name="temp2", scope="local") last[0] = 0 with irb.for_range(0, n, kind="serial", name="col") as col: temp2[0] = out_indptr_ptr[col] diff --git a/python/tvm/topi/sparse/csrmm.py b/python/tvm/topi/sparse/csrmm.py index 4d659c801103..31b6e4c06c90 100644 --- a/python/tvm/topi/sparse/csrmm.py +++ b/python/tvm/topi/sparse/csrmm.py @@ -81,7 +81,7 @@ def csrmm_default_ir(data, indices, indptr, weight, out): _, N = weight.shape with irb.for_range(0, N, kind="vectorize", name="n") as n: with irb.for_range(0, M, kind="parallel", name="row") as row: - dot = irb.allocate(data.dtype, (1,), name="dot", scope="local") + dot = irb.allocate(data.dtype, 1, name="dot", scope="local") out_ptr[row * N + n] = cast(0, data.dtype) dot[0] = cast(0, data.dtype) row_start = indptr_ptr[row] diff --git a/python/tvm/topi/sparse/csrmv.py b/python/tvm/topi/sparse/csrmv.py index 3c2016c6513a..7d54133a5a5f 100644 --- a/python/tvm/topi/sparse/csrmv.py +++ b/python/tvm/topi/sparse/csrmv.py @@ -71,7 +71,7 @@ def csrmv_default_ir(data, indices, indptr, weight, out): out_ptr = irb.buffer_ptr(out) num_rows = indptr.shape[0] - 1 with irb.for_range(0, num_rows, kind="parallel", name="row") as row: - dot = irb.allocate(data.dtype, (1,), name="dot", scope="local") + dot = irb.allocate(data.dtype, 1, name="dot", scope="local") out_ptr[row] = cast(0, data.dtype) dot[0] = cast(0, data.dtype) row_start = indptr_ptr[row] diff --git a/python/tvm/topi/sparse/dense.py b/python/tvm/topi/sparse/dense.py index 5c63e44f691a..e40ed15d2535 100644 --- a/python/tvm/topi/sparse/dense.py +++ b/python/tvm/topi/sparse/dense.py @@ -76,7 +76,7 @@ def dense_default_ir(data, indices, indptr, weight, out): N, K = weight.shape with irb.for_range(0, N, kind="vectorize", name="n") as n: with irb.for_range(0, M, kind="parallel", name="m") as m: - dot = irb.allocate(dtype, (1,), name="dot", scope="local") + dot = irb.allocate(dtype, 1, name="dot", scope="local") out_ptr[m * N + n] = tvm.tir.const(0, dtype) dot[0] = tvm.tir.const(0, dtype) row_start = indptr_ptr[m] @@ -155,7 +155,7 @@ def dense_default_ir(data, w_data, w_indices, w_indptr, out): N = simplify(w_indptr.shape[0] - 1) with irb.for_range(0, M, kind="vectorize", name="m") as m: with irb.for_range(0, N, kind="parallel", name="n") as n: - dot = irb.allocate(dtype, (1,), name="dot", scope="local") + dot = irb.allocate(dtype, 1, name="dot", scope="local") out_ptr[m * N + n] = tvm.tir.const(0, dtype) dot[0] = tvm.tir.const(0, dtype) row_start = w_indptr_ptr[n] diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py index b25bd854a7f9..7dc3bb2d96b2 100644 --- a/python/tvm/topi/sparse_reshape.py +++ b/python/tvm/topi/sparse_reshape.py @@ -89,19 +89,17 @@ def gen_ir( new_shape_size = new_shape_ptr.shape[0] multipliers = ib.allocate( - new_shape_ptr.dtype, (prev_shape_size,), name="multipliers", scope="local" - ) - dividers = ib.allocate( - new_shape_ptr.dtype, (new_shape_size,), name="dividers", scope="local" + new_shape_ptr.dtype, prev_shape_size, name="multipliers", scope="local" ) + dividers = ib.allocate(new_shape_ptr.dtype, new_shape_size, name="dividers", scope="local") 
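In sparse_reshape the extents were already scalar expressions (prev_shape_size, new_shape_size), so the tuples simply drop away. Where a scratch buffer is logically N-d, the element count is the product of its dimensions, and the caller now owns that flattening; a hedged sketch of the arithmetic (flat_extent is a hypothetical helper, not part of this patch):

    def flat_extent(shape):
        """Fold an N-d shape into the single element count Allocate now takes."""
        extent = 1
        for dim in shape:
            extent = extent * dim  # works for int dims and symbolic PrimExpr dims alike
        return extent

    # e.g. flat_extent((1, 14, 14, 512)) == 1*14*14*512, the spelled-out form that
    # appears in the workspace-analysis test updated later in this series.
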
flattened_indices = ib.allocate( new_shape_ptr.dtype, - (sparse_indices_ptr.shape[0],), + sparse_indices_ptr.shape[0], name="flattened_indices", scope="local", ) - total_ele = ib.allocate(new_shape_ptr.dtype, (1,), name="total_ele", scope="local") + total_ele = ib.allocate(new_shape_ptr.dtype, 1, name="total_ele", scope="local") total_ele[0] = prev_shape[0] # Cumulative Reverse Exclusive Multiply @@ -114,7 +112,7 @@ def gen_ir( total_ele[0] *= prev_shape[prev_shape_size - i] division_total_ele = ib.allocate( - new_shape_ptr.dtype, (1,), name="division_total_ele", scope="local" + new_shape_ptr.dtype, 1, name="division_total_ele", scope="local" ) division_total_ele[0] = Cast(new_shape_ptr.dtype, 1) with ib.for_range(0, new_shape_size) as i: @@ -130,7 +128,7 @@ def gen_ir( with ib.else_scope(): out_new_shape[i] = new_shape[i] - equal_shape = ib.allocate("bool", (1,), name="equal_shape", scope="local") + equal_shape = ib.allocate("bool", 1, name="equal_shape", scope="local") # Check if prev_shape and new_shape are equal equal_shape[0] = True @@ -163,7 +161,7 @@ def gen_ir( with ib.for_range(0, new_sparse_indices_ptr.shape[0], kind="parallel") as i: current_element = ib.allocate( - new_shape_ptr.dtype, (1,), name="current_element", scope="local" + new_shape_ptr.dtype, 1, name="current_element", scope="local" ) current_element[0] = flattened_indices[i] diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 7a51946a279a..6acaa9a3e299 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -655,9 +655,9 @@ def nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): num_valid_boxes_local = ib.allocate( - "int32", (1,), name="num_valid_boxes_local", scope="local" + "int32", 1, name="num_valid_boxes_local", scope="local" ) - box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + box_idx = ib.allocate("int32", 1, name="box_idx", scope="local") num_valid_boxes_local[0] = 0 box_idx[0] = 0 diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index d12592fd111a..7d55468571d4 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -60,8 +60,8 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): def binary_search(ib, y, num_boxes, scores, score_threshold, out): """Binary search for score_threshold on scores sorted in descending order""" - lo = ib.allocate("int32", (1,), name="lo", scope="local") - hi = ib.allocate("int32", (1,), name="hi", scope="local") + lo = ib.allocate("int32", 1, name="lo", scope="local") + hi = ib.allocate("int32", 1, name="hi", scope="local") lo[0] = 0 hi[0] = num_boxes diff --git a/python/tvm/topi/vision/rcnn/proposal.py b/python/tvm/topi/vision/rcnn/proposal.py index 12a0d6bcf0a0..23b0d5d39ebf 100644 --- a/python/tvm/topi/vision/rcnn/proposal.py +++ b/python/tvm/topi/vision/rcnn/proposal.py @@ -205,8 +205,8 @@ def argsort_ir(data_buf, out_index_buf): ib = tvm.tir.ir_builder.create() p_data = ib.buffer_ptr(data_buf) index_out = ib.buffer_ptr(out_index_buf) - temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") - temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") + temp_data = ib.allocate("float32", 1, name="temp_data", scope="local") + temp_index = ib.allocate("int32", 1, name="temp_index", scope="local") idxm = tvm.tir.indexmod with ib.for_range(0, batch, kind="unroll") as b: start = b * num_bbox @@ -316,12 +316,12 @@ def 
prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf): batch, num_bbox, _ = get_const_tuple(sorted_bbox_buf.shape) rpn_post_nms_top_n = get_const_int(out_buf.shape[0]) // batch ib = tvm.tir.ir_builder.create() - i = ib.allocate("int32", (batch,), "i", scope="local") + i = ib.allocate("int32", batch, "i", scope="local") p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf) p_remove = ib.buffer_ptr(remove_mask_buf) p_out = ib.buffer_ptr(out_buf) - nkeep = ib.allocate("int32", (batch,), "nkeep", scope="local") + nkeep = ib.allocate("int32", batch, "nkeep", scope="local") with ib.for_range(0, batch) as b: nkeep[b] = 0 From 7c86eee03446aa8f12965df684f4e1e6877ce4d0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 4 Oct 2021 12:57:36 -0500 Subject: [PATCH 5/9] Updated python unittests to use singular extent. --- .../backend/contrib/ethosu/tir/passes.py | 13 +- .../contrib/ethosu/tir_to_cs_translator.py | 2 +- .../test_ethosu/test_encode_constants.py | 12 +- .../test_ethosu/test_replace_conv2d.py | 8 +- .../contrib/test_ethosu/test_replace_copy.py | 4 +- .../test_ethosu/test_tir_to_cs_translator.py | 18 +- .../unittest/test_target_codegen_llvm.py | 4 +- .../unittest/test_target_codegen_vulkan.py | 10 +- tests/python/unittest/test_te_schedule_ops.py | 2 +- .../test_tir_analysis_calculate_workspace.py | 18 +- ...t_tir_analysis_detect_buffer_access_lca.py | 2 +- tests/python/unittest/test_tir_constructor.py | 4 +- tests/python/unittest/test_tir_ir_builder.py | 18 +- tests/python/unittest/test_tir_nodes.py | 6 +- ..._tir_transform_convert_for_loops_serial.py | 14 +- .../test_tir_transform_flatten_buffer.py | 12 +- ...test_tir_transform_inject_double_buffer.py | 7 +- ...est_tir_transform_inject_virtual_thread.py | 201 ++++++++++-------- .../test_tir_transform_lower_tvm_builtin.py | 4 +- .../test_tir_transform_lower_warp_memory.py | 25 ++- .../test_tir_transform_make_unpacked_api.py | 2 +- ...merge_dynamic_shared_memory_allocations.py | 24 ++- .../test_tir_transform_storage_flatten.py | 12 +- .../test_tir_transform_storage_rewrite.py | 16 +- .../test_tir_transform_thread_sync.py | 4 +- .../unittest/test_tir_transform_vectorize.py | 2 +- .../unittest/test_tvmscript_error_report.py | 4 +- .../unittest/test_tvmscript_roundtrip.py | 18 +- vta/python/vta/transform.py | 4 +- 29 files changed, 254 insertions(+), 216 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 8bb410e986c7..5199d8c37579 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -58,7 +58,7 @@ def ReplaceOperators(): pointer_to_producer = {} pointer_to_consumer = {} replace_output_pointer = {} - pointer_to_extents = {} + pointer_to_extent = {} def _resolve_pointers(stmt): """This pass determines information about the pointers present in the IR. @@ -75,7 +75,7 @@ def _get_loads(stmt): loads.append(stmt.buffer_var) if isinstance(stmt, tvm.tir.Allocate): - pointer_to_extents[stmt.buffer_var] = stmt.extents + pointer_to_extent[stmt.buffer_var] = stmt.extent if isinstance(stmt.body[0], tvm.tir.AttrStmt): if stmt.body[0].attr_key == "pragma_op": pointer_to_producer[stmt.buffer_var] = stmt.body[0] @@ -160,7 +160,7 @@ def _replace_pointers(stmt): # If the pointer doesn't have an extent registered to it, # this means the pointer is to a Buffer. 
In this case, we # just want to delete the memory scope attribute - if replace_pointer not in pointer_to_extents: + if replace_pointer not in pointer_to_extent: return stmt.body # Otherwise, rewrite the memory scope attribute with the new pointer return tvm.tir.AttrStmt( @@ -174,12 +174,12 @@ def _replace_pointers(stmt): # If the pointer doesn't have an extent registered to it, # this means the pointer is to a Buffer. In this case, we # just want to delete the allocation statement - if replace_pointer not in pointer_to_extents: + if replace_pointer not in pointer_to_extent: return stmt.body # Otherwise, rewrite the allocation statement with the new pointer # and the new extent replace_type = replace_pointer.type_annotation.element_type.dtype - replace_extents = pointer_to_extents[replace_pointer] + replace_extents = pointer_to_extent[replace_pointer] return tvm.tir.Allocate( replace_pointer, replace_type, replace_extents, stmt.condition, stmt.body ) @@ -404,10 +404,11 @@ def _visit_rewrite(stmt): if pointer_to_buffer[allocate_pointer] in rewrite_buffer: new_buffer = rewrite_buffer[pointer_to_buffer[allocate_pointer]] new_pointer = rewrite_pointer[allocate_pointer] + assert len(new_buffer.shape) == 1 return tvm.tir.Allocate( new_pointer, new_buffer.dtype, - new_buffer.shape, + new_buffer.shape[0], stmt.condition, stmt.body, stmt.span, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 408eab6427ca..72c9661a8df9 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -167,7 +167,7 @@ def populate_allocate_buffer_info(stmt): allocate = stmt buffer_info[allocate.buffer_var] = BufferInfo( None, - allocate.extents, + [allocate.extent], allocate.dtype, BufferType.scratch, ) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 60ed352edcfd..3a2f1411bf0e 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -45,8 +45,8 @@ def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer_7 = T.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([128], "uint8", "global") - placeholder_d_global = T.allocate([32], "uint8", "global") + placeholder_global = T.allocate(128, "uint8", "global") + placeholder_d_global = T.allocate(32, "uint8", "global") T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6.data, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) @@ -119,7 +119,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, 
placeholder_2: T.handle buffer_2 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_3 = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([4096], "int8", "global") + ethosu_write_2 = T.allocate(4096, "int8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 592, 12, T.load("uint8", buffer_2.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 160, 12, T.load("uint8", buffer_3.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None @@ -187,9 +187,9 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_8 = T.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_9 = T.match_buffer(placeholder_10, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([4096], "int8", "global") - placeholder_global = T.allocate([80], "uint8", "global") - placeholder_d_global = T.allocate([32], "uint8", "global") + ethosu_write_2 = T.allocate(4096, "int8", "global") + placeholder_global = T.allocate(80, "uint8", "global") + placeholder_d_global = T.allocate(32, "uint8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_11.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_6.data, 0), 592, 12, T.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index f76a59dd1eb3..7c172c18d1d6 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -202,7 +202,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer_3 = T.match_buffer(placeholder_1, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([1024], "int8", "global") + ethosu_write_2 = T.allocate(1024, "int8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, 
T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 304, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) @@ -223,7 +223,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle placeholder_5 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([1536], "int8", "global") + ethosu_write_2 = T.allocate(1536, "int8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 2608, 12, T.load("uint8", buffer.data, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) @@ -244,7 +244,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_3 = T.match_buffer(placeholder_1, [880], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([2560], "int8", "global") + ethosu_write_2 = T.allocate(2560, "int8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 
320, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer.data, 0), 1744, 12, T.load("uint8", buffer_1.data, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) @@ -267,7 +267,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_2 = T.match_buffer(placeholder_4, [272], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_3 = T.match_buffer(placeholder_3, [11040], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([2304], "int8", "global") + ethosu_write_2 = T.allocate(2304, "int8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 11040, 12, T.load("uint8", buffer_2.data, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 76b7ef2a70ee..2c869ea51849 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -39,8 +39,8 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_1 = T.match_buffer(placeholder_1, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([304], "uint8", "global") - placeholder_d_global = T.allocate([80], "uint8", "global") + placeholder_global = T.allocate(304, "uint8", "global") + placeholder_d_global = T.allocate(80, "uint8", "global") T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 
0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index 8240b392a1cf..79d854cf3bc5 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -60,8 +60,8 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle placeholder_8 = T.match_buffer(placeholder_2, [32], dtype="int32", elem_offset=0, align=128, offset_factor=1) placeholder_5 = T.match_buffer(placeholder_4, [8], dtype="int32", elem_offset=0, align=128, offset_factor=1) # body - ethosu_conv2d_2 = T.allocate([1024], "uint8", "global") - ethosu_conv2d_3 = T.allocate([2048], "uint8", "global") + ethosu_conv2d_2 = T.allocate(1024, "uint8", "global") + ethosu_conv2d_3 = T.allocate(2048, "uint8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7.data, 0), 0, 12, T.load("uint8", placeholder_8.data, 0), 0, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="uint8")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9.data, 0), 0, 12, T.load("uint8", placeholder_5.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7.data, 0), 0, 12, T.load("uint8", placeholder_8.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) @@ -82,8 +82,8 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle placeholder_5 = T.match_buffer(placeholder_2, [8], dtype="int32", elem_offset=0, align=128, offset_factor=1) placeholder_4 = T.match_buffer(placeholder_1, [8, 1, 1, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([256], "uint8", "global") - placeholder_d_global = T.allocate([8], "int32", "global") + placeholder_global = T.allocate(256, "uint8", "global") + placeholder_d_global = T.allocate(8, "int32", "global") T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", placeholder_4.data, 0), 256, T.load("uint8", placeholder_global, 0), dtype="handle")) 
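These TVMScript fixtures exercise the new form end to end: T.allocate(256, ...) must round-trip to an AllocateNode whose extent is the scalar 256 rather than a one-element array. For reference, a minimal sketch of constructing such a node directly, assuming the updated Python constructor (the variable name and size mirror placeholder_global above and are illustrative only):

    import tvm

    ptr = tvm.tir.Var("placeholder_global", tvm.ir.PointerType(tvm.ir.PrimType("uint8"), "global"))
    alloc = tvm.tir.Allocate(ptr, "uint8", 256, tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(0))
    assert alloc.extent.value == 256  # a single PrimExpr, no longer a list
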
T.evaluate(T.call_extern("ethosu_copy", T.load("int32", placeholder_5.data, 0), 8, T.load("int32", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 0, 12, T.load("uint8", placeholder_d_global, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) @@ -109,8 +109,8 @@ def main(placeholder: T.handle, ethosu_conv2d: T.handle, placeholder_1: T.handle placeholder_9 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer = T.match_buffer(placeholder_8, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([144], "uint8", "global") - placeholder_d_global = T.allocate([20], "uint8", "global") + placeholder_global = T.allocate(144, "uint8", "global") + placeholder_d_global = T.allocate(20, "uint8", "global") T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5.data, 0), 144, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 144, 12, T.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) @@ -148,9 +148,9 @@ def main(placeholder: T.handle, placeholder_1: T.handle, ethosu_conv2d: T.handle buffer_4 = T.match_buffer(placeholder_3, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_8 = T.match_buffer(placeholder_9, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_conv2d_2 = T.allocate([4096], "uint8", "global") - placeholder_global = T.allocate([80], "uint8", "global") - placeholder_d_global = T.allocate([20], "uint8", "global") + ethosu_conv2d_2 = T.allocate(4096, "uint8", "global") + placeholder_global = T.allocate(80, "uint8", "global") + placeholder_d_global = T.allocate(20, "uint8", "global") T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_11.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 16, 16, 0, 16, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_5.data, 0), 592, 12, T.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 8c8d601672ac..dd3508d2ed6b 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -761,7 +761,7 @@ def 
test_llvm_lower_atomic(): def do_atomic_add(A): ib = tvm.tir.ir_builder.create() n = A.shape[0] - atomic_add_return = ib.allocate(A.dtype, (1,), name="atomic_add_return", scope="local") + atomic_add_return = ib.allocate(A.dtype, 1, name="atomic_add_return", scope="local") one = tvm.tir.const(1, A.dtype) A_ptr = ib.buffer_ptr(A) with ib.for_range(0, n, name="i", kind="parallel") as i: @@ -787,7 +787,7 @@ def test_llvm_gpu_lower_atomic(): def do_atomic_add(A): ib = tvm.tir.ir_builder.create() n = A.shape[0] - atomic_add_return = ib.allocate(A.dtype, (1,), name="atomic_add_return", scope="local") + atomic_add_return = ib.allocate(A.dtype, 1, name="atomic_add_return", scope="local") one = tvm.tir.const(1, A.dtype) A_ptr = ib.buffer_ptr(A) nthread_tx = 64 diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 1edc5d311759..bf29b3cd20f2 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -333,7 +333,7 @@ def do_compute(A, B, n): if "gpu" in target.keys: ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0) - iterations = ib.allocate("int32", (1,), name="iterations", scope="local") + iterations = ib.allocate("int32", 1, name="iterations", scope="local") iterations[0] = 0 B[0] = 0 @@ -503,10 +503,10 @@ def do_compute(ins, outs): store_index = index_map[store_type] if indirect_indices: - load_index = tvm.tir.expr.Load("int32x4", R, load_index) + load_index = tvm.tir.expr.Load("int32x4", R.asobject().data, load_index) - transfer = tvm.tir.expr.Load("int32x4", A, load_index) - ib.emit(tvm.tir.stmt.Store(B, transfer, store_index)) + transfer = tvm.tir.expr.Load("int32x4", A.asobject().data, load_index) + ib.emit(tvm.tir.stmt.Store(B.asobject().data, transfer, store_index)) return ib.get() @@ -536,7 +536,7 @@ def do_compute(ins, outs): ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0) - array = ib.allocate("int32", (alloc_nbytes,), name="array", scope="shared") + array = ib.allocate("int32", alloc_nbytes, name="array", scope="shared") array[0] = 0 out[0] = array[0] diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index bc4bc4f56e19..db7fd7b1624b 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -607,7 +607,7 @@ def collect_visit(stmt, f): def visit_stmt(op): if isinstance(op, tvm.tir.Allocate): - return op.extents[0].value == 97 + return op.extent.value == 97 return False assert not any(collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse))) diff --git a/tests/python/unittest/test_tir_analysis_calculate_workspace.py b/tests/python/unittest/test_tir_analysis_calculate_workspace.py index 4b61625014e2..9a7250f2178c 100644 --- a/tests/python/unittest/test_tir_analysis_calculate_workspace.py +++ b/tests/python/unittest/test_tir_analysis_calculate_workspace.py @@ -31,8 +31,8 @@ def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.hand placeholder_149 = T.match_buffer(placeholder_146, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_49 = T.match_buffer(T_cast_48, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_22 = T.allocate([131072], "int16", "global") - DepthwiseConv2d_9 = T.allocate([100352], "int32", "global") + PaddedInput_22 = T.allocate(131072, "int16", "global") + DepthwiseConv2d_9 = 
T.allocate(100352, "int32", "global") for i1_29, i2_39, i3_40 in T.grid(16, 16, 512): PaddedInput_22[(((i1_29*8192) + (i2_39*512)) + i3_40)] = T.if_then_else(((((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15)), T.load("int16", placeholder_147.data, ((((i1_29*7168) + (i2_39*512)) + i3_40) - 7680)), T.int16(0), dtype="int16") for i_9, j_9, c_9 in T.grid(14, 14, 512): @@ -62,25 +62,25 @@ def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handl placeholder_167 = T.match_buffer(placeholder_164, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_77 = T.match_buffer(T_cast_76, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_25 = T.allocate([1, 16, 16, 512], "int16", "global") + PaddedInput_25 = T.allocate(1*16*16*512, "int16", "global") for i1_35, i2_46, i3_47 in T.grid(16, 16, 512): PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = T.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), T.load("int16", placeholder_165.data, ((((i1_35*7168) + (i2_46*512)) + i3_47) - 7680)), T.int16(0), dtype="int16") - T_add_11 = T.allocate([1, 14, 14, 512], "int32", "global") - with T.allocate([1, 14, 14, 512], "int32", "global") as DepthwiseConv2d_11: + T_add_11 = T.allocate(1*14*14*512, "int32", "global") + with T.allocate(1*14*14*512, "int32", "global") as DepthwiseConv2d_11: for i_11, j_11, c_11 in T.grid(14, 14, 512): DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = 0 for di_11, dj_11 in T.grid(3, 3): DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (T.load("int32", DepthwiseConv2d_11, (((i_11*7168) + (j_11*512)) + c_11)) + (T.load("int16", PaddedInput_25, (((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)).astype("int32")*T.load("int16", placeholder_166.data, (((di_11*1536) + (dj_11*512)) + c_11)).astype("int32"))) for ax1_44, ax2_45, ax3_47 in T.grid(14, 14, 512): T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = (T.load("int32", DepthwiseConv2d_11, (((ax1_44*7168) + (ax2_45*512)) + ax3_47)) + T.load("int32", placeholder_167.data, ax3_47)) - compute_22 = T.allocate([1, 14, 14, 512], "int32", "global") - with T.allocate([1, 14, 14, 512], "int32", "global") as T_cast_78: + compute_22 = T.allocate(1*14*14*512, "int32", "global") + with T.allocate(1*14*14*512, "int32", "global") as T_cast_78: for ax1_45, ax2_46, ax3_48 in T.grid(14, 14, 512): T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = T.load("int32", T_add_11, (((ax1_45*7168) + (ax2_46*512)) + ax3_48)) for i1_36, i2_47, i3_48 in T.grid(14, 14, 512): compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = T.q_multiply_shift(T.load("int32", T_cast_78, (((i1_36*7168) + (i2_47*512)) + i3_48)), 1948805937, 31, -5, dtype="int32") - T_cast_79 = T.allocate([1, 14, 14, 512], "uint8", "global") - with T.allocate([1, 14, 14, 512], "int32", "global") as compute_23: + T_cast_79 = T.allocate(1*14*14*512, "uint8", "global") + with T.allocate(1*14*14*512, "int32", "global") as compute_23: for i1_37, i2_48, i3_49 in T.grid(14, 14, 512): compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = T.max(T.max(T.load("int32", compute_22, (((i1_37*7168) + (i2_48*512)) + i3_49)), 255), 0) for ax1_46, ax2_47, ax3_49 in T.grid(14, 14, 512): diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 1aae8cdd03e1..cf885cd44bad 100644 --- 
a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -46,7 +46,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: with T.block([]): T.reads([]) T.writes(B[0:16, 0:16]) - A = T.allocate([256], "float32", "global") + A = T.allocate(256, "float32", "global") for i, j in T.grid(16, 16): T.store(A, i * 16 + j, 1) for i in range(0, 16): diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 00aba46ba431..2439864a6c72 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -155,7 +155,7 @@ def test_stmt_constructor(): assert x.value.value == 1 buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"))) - x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop) + x = tvm.tir.Allocate(buffer_var, "float32", 10, tvm.tir.const(1, "uint1"), nop) assert isinstance(x, tvm.tir.Allocate) assert x.dtype == "float32" assert x.buffer_var == buffer_var @@ -163,7 +163,7 @@ def test_stmt_constructor(): storage_scope = "global.texture" buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"), storage_scope)) - x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop) + x = tvm.tir.Allocate(buffer_var, "float32", 10, tvm.tir.const(1, "uint1"), nop) assert isinstance(x, tvm.tir.Allocate) assert x.dtype == "float32" assert x.buffer_var == buffer_var diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 5b123e883849..cff266ab0e0b 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -184,7 +184,7 @@ def test_ir(A, B, C): A = ib.buffer_ptr(A) B = ib.buffer_ptr(B) C = ib.buffer_ptr(C) - i = ib.allocate("int32", (1,), name="i", scope="local") + i = ib.allocate("int32", 1, name="i", scope="local") i[0] = 0 with ib.for_range(0, n) as j: @@ -242,8 +242,8 @@ def collatz_ref(n): return i def collatz(ib, n, C): - i = ib.allocate("int32", (1,), name="i", scope="local") - a = ib.allocate("int32", (1,), name="a", scope="local") + i = ib.allocate("int32", 1, name="i", scope="local") + a = ib.allocate("int32", 1, name="a", scope="local") i[0] = 0 a[0] = n with ib.while_loop(a[0] > 1): @@ -317,9 +317,9 @@ def complex_sqr(z): return pixels def mandel(ib, i, j, pixels): - z = ib.allocate("float32", (2,), name="z", scope="local") - tmp = ib.allocate("float32", (1,), name="tmp", scope="local") - iterations = ib.allocate("int32", (1,), name="iterations", scope="local") + z = ib.allocate("float32", 2, name="z", scope="local") + tmp = ib.allocate("float32", 1, name="tmp", scope="local") + iterations = ib.allocate("int32", 1, name="iterations", scope="local") z[0] = (i / float(n) - 1) * 2 z[1] = (j / float(n) - 0.5) * 2 @@ -409,8 +409,8 @@ def check_target(target, ir): def test_while_binary_search(): def binary_search(ib, n, i, Aptr, Bptr, Cptr): - lo = ib.allocate("int32", (1,), name="lo", scope="local") - hi = ib.allocate("int32", (1,), name="hi", scope="local") + lo = ib.allocate("int32", 1, name="lo", scope="local") + hi = ib.allocate("int32", 1, name="hi", scope="local") lo[0] = 0 hi[0] = n @@ -509,7 +509,7 @@ def test_device_ir(A, B): tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", n) - temp = ib.allocate(dtype, (n,), scope="shared.dyn") # n is symbolic size + temp = ib.allocate(dtype, n, 
scope="shared.dyn") # n is symbolic size Aptr = ib.buffer_ptr(A) Bptr = ib.buffer_ptr(B) diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index fe719ee99693..c663dd0dd627 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -481,7 +481,7 @@ def test_tir_allocate(): allocate = tvm.tir.Allocate( buffer_var=a, dtype=dtype, - extents=[2, 2], + extent=4, condition=tvm.get_global_func("tir.const_true")(dtype, None), body=tvm.tir.Evaluate(2 + 1), annotations={ @@ -491,7 +491,7 @@ def test_tir_allocate(): ) assert allocate.buffer_var == a assert allocate.dtype == "int8" - assert list(allocate.extents) == [2, 2] + assert allocate.extent == 4 assert allocate.annotations["attr1"] == "foo" assert allocate.annotations["attr2"] == "bar" @@ -500,7 +500,7 @@ def test_tir_allocate(): output = func.astext() assert ( output.find( - 'allocate(buffer: Pointer(global int8), int8, [2, 2]), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})' + 'allocate(buffer: Pointer(global int8), int8, 4), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})' ) != -1 ) diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py index a91fa2591e00..aca80dc91f86 100644 --- a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -31,17 +31,19 @@ def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T. placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global") + PaddedInput_3 = T.buffer_decl([1,28,28,192], dtype='int16', scope='global') + T.realize(PaddedInput_3[0:1, 0:28, 0:28, 0:192], '') for i0_i1_fused_3 in T.parallel(0, 28): for i2_3, i3_3 in T.grid(28, 192): - T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + PaddedInput_3[0, i0_i1_fused_3, i2_3, i3_3] = placeholder_33[0, i0_i1_fused_3, i2_3, i3_3] for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784): for ax3_2 in T.serial(0, 16): - Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global") - T.store(Conv2dOutput_3, 0, 0, True) + Conv2dOutput_3 = T.buffer_decl([1], dtype='int32', scope='global') + T.realize(Conv2dOutput_3[0:1], '') + Conv2dOutput_3[0] = 0 for rc_3 in T.serial(0, 192): - T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) - T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + Conv2dOutput_3[0] = Conv2dOutput_3[0] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3//28, ax0_ax1_fused_ax2_fused_3%28, rc_3], "int32")*T.cast(placeholder_34[0,0,rc_3,ax3_2], "int32") + T_cast_9[ax0_ax1_fused_ax2_fused_3//28, ax0_ax1_fused_ax2_fused_3%28, ax3_2] = 
T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35[0,0,0,ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") # fmt: on diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 21c896c7bb7e..6d5241b2a4c9 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -53,7 +53,7 @@ def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in T.serial(0, 16): - B_new = T.allocate([16], "float32", "global") + B_new = T.allocate(16, "float32", "global") for j in T.serial(0, 16): B_new[j] = T.load("float32", A.data, ((i * 16) + j)) + 1.0 for j in T.serial(0, 16): @@ -95,7 +95,7 @@ def flattened_gpu_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.allocate([16], "float32", "local") + B = T.allocate(16, "float32", "local") for j in range(0, 16): B[j] = T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + 1.0 for j in range(0, 16): @@ -130,7 +130,7 @@ def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - B = T.allocate([m], "float32", "global") + B = T.allocate(m, "float32", "global") for j in range(0, m): B[j] = T.load("float32", A.data, i * m + j) + 1.0 for j in range(0, m): @@ -203,8 +203,8 @@ def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (32), "float32") for i in range(0, 32): - B = T.allocate((32,), "float32", "global") - C = T.allocate((32,), "float32", "global") + B = T.allocate(32, "float32", "global") + C = T.allocate(32, "float32", "global") B[i] = T.load("float32", A.data, i) + 1.0 C[i] = T.load("float32", A.data, i) + T.load("float32", B, i) D.data[i] = T.load("float32", C, i) * 2.0 @@ -238,7 +238,7 @@ def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i0 in T.serial(0, 4): - B_new = T.allocate([68], "float32", "global") + B_new = T.allocate(68, "float32", "global") for i1 in T.serial(0, 4): for j in T.serial(0, 16): B_new[i1 * 17 + j] = T.load("float32", A.data, i0 * 64 + i1 * 16 + j) + 1.0 diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index 9b37bcaaacbc..821dc9a3cffa 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -47,10 +47,13 @@ def test_double_buffer(): mod = opt(mod) stmt = mod["db"].body + # Allocation of B is now twice as large assert isinstance(stmt.body, tvm.tir.Allocate) - assert stmt.body.extents[0].value == 2 + assert stmt.body.extent.value == 2 * m + + mod = tvm.tir.transform.ThreadSync("shared")(mod) + f = mod["db"] - f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] def count_sync(op): diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 673267a9b1fa..97f7ca9db617 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -14,95 +14,124 @@ # KIND, 
either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import sys + +import pytest + import tvm +import tvm.testing from tvm import te +vthread_name = tvm.testing.parameter( + "vthread", + "cthread", +) +buffer_size = tvm.testing.parameter(4) +nthread = tvm.testing.parameter(2) -def test_vthread(): - dtype = "int64" - n = 100 - m = 4 - nthread = 2 - - def get_vthread(name): - tx = te.thread_axis(name) - ty = te.thread_axis(name) - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - C = ib.pointer("float32", name="C") - with ib.for_range(0, n) as i: - ib.scope_attr(tx, "virtual_thread", nthread) - ib.scope_attr(ty, "virtual_thread", nthread) - B = ib.allocate("float32", m, name="B", scope="shared") - B[i] = A[i * nthread + tx] - bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) - ib.emit( - tvm.tir.call_extern( - "int32", - "Run", - bbuffer.access_ptr("r"), - tvm.tir.call_intrin("int32", "tir.tvm_context_id"), - ) - ) - C[i * nthread + tx] = B[i] + 1 - return ib.get() - - stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("vthread"))) - )["main"] - - assert stmt.body.body.extents[0].value == 2 - - stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"] - - assert len(stmt.body.body.extents) == 3 - - -def test_vthread_extern(): - dtype = "int64" - n = 100 - m = 4 - nthread = 2 - - def get_vthread(name): - tx = te.thread_axis(name) - ty = te.thread_axis(name) - ib = tvm.tir.ir_builder.create() - with ib.for_range(0, n) as i: - ib.scope_attr(tx, "virtual_thread", nthread) - ib.scope_attr(ty, "virtual_thread", nthread) - A = ib.allocate("float32", m, name="A", scope="shared") - B = ib.allocate("float32", m, name="B", scope="shared") - C = ib.allocate("float32", m, name="C", scope="shared") - cbuffer = tvm.tir.decl_buffer((m,), dtype=C.dtype, data=C.asobject()) - abuffer = tvm.tir.decl_buffer((m,), dtype=A.dtype, data=A.asobject()) - bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) - A[tx] = tx + 1.0 - B[ty] = ty + 1.0 - ib.emit( - tvm.tir.call_extern( - "int32", - "Run", - abuffer.access_ptr("r"), - bbuffer.access_ptr("r"), - cbuffer.access_ptr("rw"), - ) + +@tvm.testing.fixture +def vthread_mod(vthread_name, buffer_size, nthread): + loop_extent = 100 + + tx = te.thread_axis(vthread_name) + ty = te.thread_axis(vthread_name) + ib = tvm.tir.ir_builder.create() + A = ib.pointer("float32", name="A") + C = ib.pointer("float32", name="C") + with ib.for_range(0, loop_extent) as i: + ib.scope_attr(tx, "virtual_thread", nthread) + ib.scope_attr(ty, "virtual_thread", nthread) + B = ib.allocate("float32", buffer_size, name="B", scope="shared") + B[i] = A[i * nthread + tx] + bbuffer = tvm.tir.decl_buffer((buffer_size,), dtype=B.dtype, data=B.asobject()) + ib.emit( + tvm.tir.call_extern( + "int32", + "Run", + bbuffer.access_ptr("r"), + tvm.tir.call_intrin("int32", "tir.tvm_context_id"), ) - return ib.get() + ) + C[i * nthread + tx] = B[i] + 1 - stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"] + return tvm.IRModule.from_expr(tvm.tir.PrimFunc([], ib.get())) - assert stmt.body.body.extents[0].value == 2 - assert stmt.body.body.body.body.extents[0].value == 2 - assert len(stmt.body.body.body.body.extents) == 3 +def test_vthread(vthread_mod, vthread_name, buffer_size, 
nthread):
+    mod = tvm.tir.transform.InjectVirtualThread()(vthread_mod)
+    stmt = mod["main"]
 
-def test_vthread_if_then_else():
-    nthread = 2
+    if vthread_name == "vthread":
+        # All virtual thread axes whose names start with "vthread" share
+        # the same iteration index, similar to threadIdx.x, so the
+        # number of virtual threads is nthread.
+        expected_buffer_size = buffer_size * nthread
+    elif vthread_name == "cthread":
+        # All other virtual thread axes are independent, so tx and ty
+        # together give a total of nthread*nthread virtual threads.
+        expected_buffer_size = buffer_size * nthread * nthread
+    else:
+        raise ValueError(f"Unexpected vthread_name: {vthread_name}")
+
+    assert stmt.body.body.extent.value == expected_buffer_size
+
+
+@tvm.testing.fixture
+def vthread_extern_mod(vthread_name, buffer_size, nthread):
+    loop_extent = 100
+
+    tx = te.thread_axis(vthread_name)
+    ty = te.thread_axis(vthread_name)
+    ib = tvm.tir.ir_builder.create()
+    with ib.for_range(0, loop_extent) as i:
+        ib.scope_attr(tx, "virtual_thread", nthread)
+        ib.scope_attr(ty, "virtual_thread", nthread)
+        A = ib.allocate("float32", buffer_size, name="A", scope="shared")
+        B = ib.allocate("float32", buffer_size, name="B", scope="shared")
+        C = ib.allocate("float32", buffer_size, name="C", scope="shared")
+        cbuffer = tvm.tir.decl_buffer((buffer_size,), dtype=C.dtype, data=C.asobject())
+        abuffer = tvm.tir.decl_buffer((buffer_size,), dtype=A.dtype, data=A.asobject())
+        bbuffer = tvm.tir.decl_buffer((buffer_size,), dtype=B.dtype, data=B.asobject())
+        A[tx] = tx + 1.0
+        B[ty] = ty + 1.0
+        ib.emit(
+            tvm.tir.call_extern(
+                "int32",
+                "Run",
+                abuffer.access_ptr("r"),
+                bbuffer.access_ptr("r"),
+                cbuffer.access_ptr("rw"),
+            )
+        )
+    return tvm.IRModule.from_expr(tvm.tir.PrimFunc([], ib.get()))
+
+
+def test_vthread_extern(vthread_extern_mod, vthread_name, buffer_size, nthread):
+    mod = tvm.tir.transform.InjectVirtualThread()(vthread_extern_mod)
+    stmt = mod["main"]
+
+    if vthread_name == "vthread":
+        # The shared A and B buffers are only exposed as read-only to
+        # the external function, so they can still share the allocated
+        # space.
+ ro_buffer_size = buffer_size * nthread + rw_buffer_size = buffer_size * nthread * nthread + elif vthread_name == "cthread": + ro_buffer_size = buffer_size * nthread * nthread + rw_buffer_size = buffer_size * nthread * nthread + else: + raise ValueError(f"Unexpected vthread_name: {vthread_name}") + + A_alloc = stmt.body.body + C_alloc = A_alloc.body.body + assert A_alloc.extent.value == ro_buffer_size + assert C_alloc.extent.value == rw_buffer_size + + +@tvm.testing.fixture +def vthread_if_then_else_mod(nthread): tx = te.thread_axis("vthread") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") @@ -115,17 +144,15 @@ def test_vthread_if_then_else(): B[i] = A[i * nthread + tx] + 1 with ib.if_scope(i == 0): B[i] = A[i * nthread + tx] + 2 - stmt = ib.get() + return tvm.IRModule.from_expr(tvm.tir.PrimFunc([], ib.get())) + - stmt = tvm.tir.transform.InjectVirtualThread()( - tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) - )["main"] +def test_vthread_if_then_else(vthread_if_then_else_mod): + stmt = tvm.tir.transform.InjectVirtualThread()(vthread_if_then_else_mod)["main"] assert stmt.body.body.body[0].else_case != None assert stmt.body.body.body[1].else_case == None if __name__ == "__main__": - test_vthread_extern() - test_vthread() - test_vthread_if_then_else() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 63772dea65d7..7b38fc6120c3 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -157,8 +157,8 @@ def build_tir(): Aptr[0] = packed_echo(tvm.tir.const(expected_value[0], "float32")) # return handle # let Aptr_var = testing.echo(Aptr) in Aptr_var[1] = expected_value[1] - Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject())) - ib.emit(tvm.tir.Store(Aptr, tvm.tir.const(expected_value[1], "float32"), 1)) + Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject().data)) + ib.emit(tvm.tir.Store(Aptr_var, tvm.tir.const(expected_value[1], "float32"), 1)) stmt = ib.get() return tvm.IRModule.from_expr( diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index 675a7feb3b1f..43c45ba457e7 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -22,15 +22,19 @@ import tvm.testing -@tvm.testing.requires_cuda -def test_lower_warp_memory_local_scope(): - m = 128 - A = te.placeholder((m,), name="A") - B = te.compute((m,), lambda i: A[i] + 3, name="B") +@tvm.testing.parametrize_targets("cuda") +def test_lower_warp_memory_local_scope(target): + target = tvm.target.Target(target) + assert target.thread_warp_size == 32 + + arr_size = 128 + cache_size = 64 + A = te.placeholder((arr_size,), name="A") + B = te.compute((arr_size,), lambda i: A[i] + 3, name="B") s = te.create_schedule(B.op) AA = s.cache_read(A, "warp", [B]) - xo, xi = s[B].split(B.op.axis[0], 64) + xo, xi = s[B].split(B.op.axis[0], cache_size) xi0, xi1 = s[B].split(xi, factor=32) tx = te.thread_axis("threadIdx.x") s[B].bind(xi1, tx) @@ -39,17 +43,16 @@ def test_lower_warp_memory_local_scope(): xo, xi = s[AA].split(s[AA].op.axis[0], 32) s[AA].bind(xi, tx) - cuda_target = tvm.target.Target("cuda") - assert cuda_target.thread_warp_size == 32 mod = tvm.lower(s, [A, B], name="f") - mod = tvm.tir.transform.Apply(lambda f: 
f.with_attr("target", cuda_target))(mod)
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
-    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
+    mod = tvm.tir.transform.LowerWarpMemory()(mod)
+    fdevice = mod["f_kernel0"]
 
     allocate = fdevice.body.body
     assert allocate.buffer_var.type_annotation.storage_scope == "local"
-    assert fdevice.body.body.extents[0].value == 2
+    assert fdevice.body.body.extent.value * target.thread_warp_size == cache_size
 
 
 @tvm.testing.requires_cuda
diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py
index 9d917466758b..649e7e6064d5 100644
--- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py
+++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py
@@ -132,7 +132,7 @@ def test_body():
     ib = tvm.tir.ir_builder.create()
     A = tvm.tir.decl_buffer(name="A", shape=[1])
     B = tvm.tir.decl_buffer(name="B", shape=[1])
-    C = ib.buffer_ptr(A)
+    C = ib.buffer_ptr(A.data)
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, B, C], stmt))
diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
index 9c511f1de6b9..00ec510a9759 100644
--- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
+++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -41,21 +41,23 @@ def run_passes(sch, args):
 
 def verify_single_allocation(stmt, alloc_size=None):
     num_alloc = [0]
-    alloc_extents = []
+    alloc_extent = 1
 
     def verify(n):
+        nonlocal alloc_extent
+
         if (
             isinstance(n, tvm.tir.Allocate)
             and n.buffer_var.type_annotation.storage_scope == "shared.dyn"
         ):
             num_alloc[0] += 1
-            alloc_extents.append(n.extents[0])
+            alloc_extent *= n.extent
 
     tvm.tir.stmt_functor.post_order_visit(stmt, verify)
     assert num_alloc[0] == 1
 
     if alloc_size:
-        assert alloc_extents[0] == alloc_size
+        assert alloc_extent == alloc_size
 
 
 @tvm.testing.requires_gpu
@@ -80,12 +82,12 @@ def test_matmul_ir(A, B, C):
         ib.scope_attr(bx, "thread_extent", n // block)
         ib.scope_attr(by, "thread_extent", n // block)
 
-        A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn", name="A_sh")  # fp16
-        B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn", name="B_sh")  # fp16
+        A_sh = ib.allocate(A.dtype, block * block, scope="shared.dyn", name="A_sh")  # fp16
+        B_sh = ib.allocate(B.dtype, block * block, scope="shared.dyn", name="B_sh")  # fp16
         # Create a dynamic shared memory for the accumulation.
         # This is for testing merging dynamic shared memory alloctions with different data type.
         # In practice, there is no need to allocate a shared memory for C.
-        C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh")  # fp32
+        C_sh = ib.allocate(C.dtype, block * block, scope="shared.dyn", name="C_sh")  # fp32
 
         A_ptr = ib.buffer_ptr(A)
         B_ptr = ib.buffer_ptr(B)
@@ -155,8 +157,8 @@ def test_device_ir(A, B, C):
         tx = te.thread_axis("threadIdx.x")
         ib.scope_attr(tx, "thread_extent", tvm.tir.indexdiv(n, values_per_thread))
 
-        A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn")  # fp16
-        B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn")  # fp32
+        A_sh = ib.allocate(A.dtype, n, scope="shared.dyn")  # fp16
+        B_sh = ib.allocate(B.dtype, n, scope="shared.dyn")  # fp32
 
         Aptr = ib.buffer_ptr(A)
         Bptr = ib.buffer_ptr(B)
@@ -218,9 +220,9 @@ def test_device_ir(A, B, C, D):
         tx = te.thread_axis("threadIdx.x")
         ib.scope_attr(tx, "thread_extent", n)
 
-        A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn", name="A_sh")
-        B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn", name="B_sh")
-        C_sh = ib.allocate(C.dtype, (C.shape[0],), scope="shared.dyn", name="C_sh")
+        A_sh = ib.allocate(A.dtype, n, scope="shared.dyn", name="A_sh")
+        B_sh = ib.allocate(B.dtype, n, scope="shared.dyn", name="B_sh")
+        C_sh = ib.allocate(C.dtype, C.shape[0], scope="shared.dyn", name="C_sh")
 
         Aptr = ib.buffer_ptr(A)
         Bptr = ib.buffer_ptr(B)
diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py
index 37223493a8b5..ba508a6c0b4a 100644
--- a/tests/python/unittest/test_tir_transform_storage_flatten.py
+++ b/tests/python/unittest/test_tir_transform_storage_flatten.py
@@ -81,25 +81,25 @@ def test_flatten_storage_align():
     )(mod)
     stmt = mod["main"].body
 
-    assert stmt.extents[0].value == 17 * 8
+    assert stmt.extent.value == 17 * 8
 
 
 def test_flatten_double_buffer():
     dtype = "int64"
     n = 100
-    m = 4
+    buffer_size = 4
     tx = te.thread_axis("threadIdx.x")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
     C = ib.pointer("float32", name="C")
     ib.scope_attr(tx, "thread_extent", 1)
     with ib.for_range(0, n) as i:
-        B = ib.allocate("float32", m, name="B", scope="shared")
+        B = ib.allocate("float32", buffer_size, name="B", scope="shared")
         with ib.new_scope():
             ib.scope_attr(B.asobject(), "double_buffer_scope", 1)
-            with ib.for_range(0, m) as j:
+            with ib.for_range(0, buffer_size) as j:
                 B[j] = A[i * 4 + j]
-            with ib.for_range(0, m) as j:
+            with ib.for_range(0, buffer_size) as j:
                 C[j] = B[j] + 1
     stmt = ib.get()
 
@@ -119,7 +119,7 @@ def test_flatten_double_buffer():
     stmt = mod["main"].body
 
     assert isinstance(stmt.body, tvm.tir.Allocate)
-    assert stmt.body.extents[0].value == 2
+    assert stmt.body.extent.value == 2 * buffer_size
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db"))
     f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index 9e738b136b17..22446f8b6516 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -87,7 +87,7 @@ def test_alloc_seq():
     def verify(n):
         if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
-            assert n.extents[0].value == 200
+            assert n.extent.value == 200
 
     tvm.tir.stmt_functor.post_order_visit(body, verify)
     assert num_alloc[0] == 1
@@ -137,7 +137,7 @@ def offset_generater(dtype_list, length):
     def dtype_test(dtype_list, length):
         def verify(n):
             if isinstance(n, tvm.tir.Allocate):
-                assert n.extents[0].value == offset
+
assert n.extent.value == offset body = stmt_generater(dtype_list, length) offset = offset_generater(dtype_list, length) @@ -222,7 +222,7 @@ def test_storage_combine(): def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 - assert n.extents[0].value == 16 + assert n.extent.value == 16 tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert num_alloc[0] == 1 @@ -527,7 +527,7 @@ def test_inplace_rule3(): # verify inplace folding works def verify(n): if isinstance(n, tvm.tir.Allocate): - assert n.extents[0].value == 70 + assert n.extent.value == 70 tvm.tir.stmt_functor.post_order_visit(stmt, verify) @@ -560,7 +560,7 @@ def test_alloc_seq_type(): def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 - assert n.extents[0].value == 500 + assert n.extent.value == 500 tvm.tir.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 @@ -595,7 +595,7 @@ def test_alloc_seq_type2(): def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 - assert n.extents[0].value == 200 + assert n.extent.value == 200 tvm.tir.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 @@ -629,7 +629,7 @@ def test_reuse_small_buffer(): def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 - assert n.extents[0].value == 800 + assert n.extent.value == 800 tvm.tir.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 @@ -670,7 +670,7 @@ def compute(a, b): def verify(n): if isinstance(n, tvm.tir.Allocate): - assert n.extents[0].value == 268435456 + assert n.extent.value == 268435456 tvm.tir.stmt_functor.post_order_visit(stmt, verify) diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index ffdf4b5916c4..c21ac22862c7 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -69,8 +69,8 @@ def ir(A, B): tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", 1) - local = ib.allocate(A.dtype, (8,), name="buf_local", scope="local") - shared = ib.allocate(A.dtype, (8,), name="buf_shared", scope="shared") + local = ib.allocate(A.dtype, 8, name="buf_local", scope="local") + shared = ib.allocate(A.dtype, 8, name="buf_shared", scope="shared") with ib.for_range(0, 8) as i: with ib.if_scope(Aptr[i] < 0): diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py index b1e580957b24..5a6e7f682996 100644 --- a/tests/python/unittest/test_tir_transform_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -170,7 +170,7 @@ def test_ir(A, B, C): A = ib.buffer_ptr(A) B = ib.buffer_ptr(B) C = ib.buffer_ptr(C) - i = ib.allocate("int32", (1,), name="i", scope="local") + i = ib.allocate("int32", 1, name="i", scope="local") i[0] = 0 with ib.for_range(0, n) as j: diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 99a22636b927..076d707e361b 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -147,7 +147,7 @@ def test_no_body(): def allocate_with_buffers() -> None: - with T.allocate([1], "float32", "") as [A, B]: # error + with T.allocate(1, "float32", "") as [A, B]: # error T.evaluate(1.0) @@ -384,7 +384,7 @@ def test_match_buffer_shape_mismatch(): def high_dim_store() -> None: with T.block([], "root"): - B = T.allocate([256], "float32", "global") + B = 
T.allocate(256, "float32", "global") for i, j in T.grid(16, 16): B[i, j] = 1.0 # error: Store is only allowed with one index diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 8058b96b024d..020cacff0a79 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -93,7 +93,7 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) # body - packedB = T.allocate([32768], "float32x32", "global") + packedB = T.allocate(32768, "float32x32", "global") for x in T.parallel(0, 32): for y in T.serial(0, 1024): T.store( @@ -108,7 +108,7 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: T.broadcast(True, 32), ) for x_outer in T.parallel(0, 32): - C_global = T.allocate([1024], "float32", "global") + C_global = T.allocate(1024, "float32", "global") for y_outer in T.serial(0, 32): for x_c_init in T.serial(0, 32): T.store( @@ -1080,11 +1080,11 @@ def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: ty = T.env_thread("threadIdx.y") tz = T.env_thread("threadIdx.z") T.launch_thread(bz, 196) - Conv_wmma_accumulator = T.allocate([2048], "float32", "wmma.accumulator") - Apad_shared = T.allocate([12288], "float16", "shared") - W_shared = T.allocate([12288], "float16", "shared") - Apad_shared_wmma_matrix_a = T.allocate([512], "float16", "wmma.matrix_a") - W_shared_wmma_matrix_b = T.allocate([1024], "float16", "wmma.matrix_b") + Conv_wmma_accumulator = T.allocate(2048, "float32", "wmma.accumulator") + Apad_shared = T.allocate(12288, "float16", "shared") + W_shared = T.allocate(12288, "float16", "shared") + Apad_shared_wmma_matrix_a = T.allocate(512, "float16", "wmma.matrix_a") + W_shared_wmma_matrix_b = T.allocate(1024, "float16", "wmma.matrix_b") T.launch_thread(bx, 2) T.launch_thread(by, 4) T.launch_thread(ty, 4) @@ -2653,7 +2653,7 @@ def vthread_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.allocate([16], "float32", "local") + B = T.allocate(16, "float32", "local") for j in range(16): B[j] = T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + T.float32(1) for j in range(16): @@ -3067,7 +3067,7 @@ def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.han placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - tensor_2 = T.allocate([200704], "uint8", "global", annotations={"attr1_key": "attr1_value"}) + tensor_2 = T.allocate(200704, "uint8", "global", annotations={"attr1_key": "attr1_value"}) for ax0_ax1_fused_4 in T.serial(0, 56): for ax2_4 in T.serial(0, 56): for ax3_init in T.serial(0, 64): diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 383841f19e34..40672b24516b 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -172,7 +172,7 @@ def _post_order(op): ), op.body, ) - alloc = tvm.tir.Allocate(buffer_var, op.dtype, op.extents, op.condition, let_stmt) + alloc = tvm.tir.Allocate(buffer_var, op.dtype, op.extent, op.condition, let_stmt) del rw_info[buffer_var] return alloc if isinstance(op, tvm.tir.Load): @@ -226,7 +226,7 @@ def 
_merge_block(slist, body):
             if op.body == body:
                 body = op
             elif isinstance(op, tvm.tir.Allocate):
-                body = tvm.tir.Allocate(op.buffer_var, op.dtype, op.extents, op.condition, body)
+                body = tvm.tir.Allocate(op.buffer_var, op.dtype, op.extent, op.condition, body)
             elif isinstance(op, tvm.tir.AttrStmt):
                 body = tvm.tir.AttrStmt(op.node, op.attr_key, op.value, body)
             elif isinstance(op, tvm.tir.For):

From bc85435950683a9a1bc151f6b725bdea18072d61 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Thu, 30 Sep 2021 14:26:25 -0500
Subject: [PATCH 6/9] [TIR] Added Simplify() to AllocateNode::constant_allocation_size()

Some of the updated unit tests express an allocation size as a product
rather than as a single integer, for readability. Those products should
still be treated as constant integers when determining the required
workspace size. A short sketch of this behavior follows the test
updates in PATCH 8/9 below.

---
 src/tir/ir/stmt.cc | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index bf37d34ddaef..f2f46e32af3b 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -358,7 +358,11 @@ Allocate::Allocate(Var buffer_var, DataType dtype, PrimExpr extent, PrimExpr con
 }
 
 int32_t AllocateNode::constant_allocation_size(const PrimExpr& extent) {
-  if (const IntImmNode* int_size = extent.as<IntImmNode>()) {
+  arith::Analyzer analyzer;
+
+  PrimExpr simplified = analyzer.Simplify(extent);
+
+  if (const IntImmNode* int_size = simplified.as<IntImmNode>()) {
     return int_size->value;
   } else {
     return 0;

From 0773744ec56a838b85678d9a3314cfe75d046f52 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 4 Oct 2021 12:37:13 -0500
Subject: [PATCH 7/9] [TIR] Updated UnrollLoop to count BufferStore

If UnrollLoop is applied before StorageFlatten, there may be
BufferLoad/BufferStore nodes that haven't yet been lowered to
Load/Store nodes. These should be counted the same way as Store nodes.

---
 src/tir/transforms/unroll_loop.cc | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc
index c6e0b5c5f41e..9574628d3773 100644
--- a/src/tir/transforms/unroll_loop.cc
+++ b/src/tir/transforms/unroll_loop.cc
@@ -133,6 +133,11 @@ class LoopUnroller : public StmtExprMutator {
     }
   }
 
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    ++step_count_;
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
   Stmt VisitStmt_(const StoreNode* op) final {
     ++step_count_;
     return StmtExprMutator::VisitStmt_(op);

From 298a9889787d11fe0674a5805eebdcd50cfa28f9 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 4 Oct 2021 12:43:01 -0500
Subject: [PATCH 8/9] [UnitTest] Split UnrollLoop tests into separate tests for
 each behavior.

These changes were not required for the UnrollLoop fix in the previous
commit, but were made during the same investigation. They are kept as a
separate commit in the PR for ease of review.

---
 .../test_tir_transform_unroll_loop.py         | 73 ++++++++++++-------
 1 file changed, 48 insertions(+), 25 deletions(-)

diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py
index b511118f8b52..4989742dcec7 100644
--- a/tests/python/unittest/test_tir_transform_unroll_loop.py
+++ b/tests/python/unittest/test_tir_transform_unroll_loop.py
@@ -14,13 +14,20 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+import sys
+
+import pytest
+
 import tvm
+import tvm.testing
 from tvm import te
-import os
 
 
-def test_unroll_loop():
+@tvm.testing.fixture
+def loop_module():
     ib = tvm.tir.ir_builder.create()
+
     dtype = "int64"
     n = te.size_var("n")
     Ab = tvm.tir.decl_buffer((n,), dtype)
@@ -31,41 +37,58 @@
         Aptr[j + 1] = Aptr[i] + 1
     stmt = ib.get()
 
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
-    assert isinstance(stmt, tvm.tir.For)
+    return tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
+
 
+def test_auto_unroll_enabled_below_limit(loop_module):
     with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 16}}):
-        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
-        assert not isinstance(ret, tvm.tir.For)
+        mod = tvm.tir.transform.UnrollLoop()(loop_module)
+        body = mod["main"].body
+        assert not isinstance(body, tvm.tir.For)
+
 
+def test_auto_unroll_disabled_above_limit(loop_module):
     with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 15}}):
-        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
-        assert isinstance(ret, tvm.tir.For)
+        mod = tvm.tir.transform.UnrollLoop()(loop_module)
+        body = mod["main"].body
+        assert isinstance(body, tvm.tir.For)
+
 
+def test_explicit_unroll(loop_module):
     with tvm.transform.PassContext(
         config={"tir.UnrollLoop": {"auto_max_step": 16, "explicit_unroll": False}}
     ):
-        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
-        assert isinstance(ret, tvm.tir.For)
-        assert ret.kind == tvm.tir.ForKind.UNROLLED
+        mod = tvm.tir.transform.UnrollLoop()(loop_module)
+        body = mod["main"].body
+        assert isinstance(body, tvm.tir.For)
+        assert body.kind == tvm.tir.ForKind.UNROLLED
+
 
+@tvm.testing.fixture
+def loop_module_pragma_sequential(loop_module):
+    orig_body = loop_module["main"].body
     ib = tvm.tir.ir_builder.create()
     ib.scope_attr(tvm.tir.const(0, "int32"), "pragma_auto_unroll_max_step", 16)
-    ib.emit(stmt)
+    ib.emit(orig_body)
     wrapped = ib.get()
-    wrapped = tvm.tir.SeqStmt([wrapped, stmt])
-    assert isinstance(ret, tvm.tir.For)
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], wrapped))
+    body = tvm.tir.SeqStmt([wrapped, orig_body])
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc(loop_module["main"].params, body))
+    return mod
+
+
+def test_pragma_unroll(loop_module_pragma_sequential):
     with tvm.transform.PassContext(
         config={"tir.UnrollLoop": {"auto_max_depth": 8, "explicit_unroll": False}}
     ):
-        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
-        assert isinstance(ret[0], tvm.tir.For)
-        assert ret[0].kind == tvm.tir.ForKind.UNROLLED
-        assert isinstance(ret[1], tvm.tir.For)
-        assert ret[1].kind != tvm.tir.ForKind.UNROLLED
+        mod = tvm.tir.transform.UnrollLoop()(loop_module_pragma_sequential)
+        body = mod["main"].body
+        assert isinstance(body[0], tvm.tir.For)
+        assert body[0].kind == tvm.tir.ForKind.UNROLLED
+        assert isinstance(body[1], tvm.tir.For)
+        assert body[1].kind != tvm.tir.ForKind.UNROLLED
 
 
 def test_unroll_fake_loop():
@@ -89,8 +112,9 @@ def test_unroll_fake_loop():
             "tir.UnrollLoop": {"auto_max_depth": 8, "auto_max_extent": 1, "explicit_unroll": False}
         }
     ):
-        ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
-        assert isinstance(ret[0], tvm.tir.Store)
+        mod = tvm.tir.transform.UnrollLoop()(mod)
+        body = mod["main"].body
+        assert isinstance(body[0], tvm.tir.BufferStore)
 
 
 def test_unroll_single_count_loops():
@@ -111,6 +135,4 @@ def test_unroll_single_count_loops():
 
 
 if __name__ == "__main__":
-    test_unroll_loop()
-    test_unroll_fake_loop()
-    test_unroll_single_count_loops()
+    sys.exit(pytest.main(sys.argv))
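
As a quick illustration of the PATCH 6/9 behavior referenced above: the
sketch below is not part of the patch series. It uses the Python
tvm.arith.Analyzer API to mirror the C++ arith::Analyzer call inside
AllocateNode::constant_allocation_size(), and builds the Mul nodes
explicitly so that eager constant folding in the Python operators does
not hide the simplification step.

    import tvm
    from tvm import tir

    # An allocation extent written as the product 4 * 2 * 2, constructed
    # directly as Mul nodes so that it reaches the analyzer un-folded.
    extent = tir.Mul(
        tir.Mul(tir.const(4, "int32"), tir.const(2, "int32")), tir.const(2, "int32")
    )

    # Mirrors the Simplify() call added in PATCH 6/9: the product folds
    # down to a single IntImm, so it still counts as a constant size.
    simplified = tvm.arith.Analyzer().simplify(extent)

    assert isinstance(simplified, tir.IntImm)
    assert simplified.value == 16

This is why the updated unit tests can write extents such as
buffer_size * nthread * nthread without breaking the workspace-size
calculation.
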
From 9b1d6e0bca967f50c79fa42933207d9a42c679c0 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Wed, 6 Oct 2021 14:21:24 -0500
Subject: [PATCH 9/9] [CppTest] Updated IRF.StmtMutator

The previous tests of CopyOnWrite semantics in StmtMutator relied on
AllocateNode having an Array parameter. These tests are rewritten to
check the same behavior, implemented over SeqStmtNode instead.

---
 tests/cpp/ir_functor_test.cc | 60 ++++++++++++++++++++++++------------
 1 file changed, 41 insertions(+), 19 deletions(-)

diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index d065b65c9cd1..59d72b359f2d 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -169,7 +169,7 @@ TEST(IRF, StmtVisitor) {
     Stmt body = Evaluate(z);
     DataType dtype = DataType::Float(32);
     Var buffer("b", PointerType(PrimType(dtype)));
-    return Allocate(buffer, dtype, z, const_true(), body);
+    return Allocate(buffer, dtype, z * z, const_true(), body);
   };
   v(fmaketest());
   ICHECK_EQ(v.count, 3);
@@ -218,6 +218,14 @@ TEST(IRF, StmtMutator) {
     return Allocate(buffer, dtype, z, const_true(), body);
   };
 
+  auto fmakealloc_seq_body = [&]() {
+    auto z = x + 1;
+    Stmt body = Evaluate(z);
+    DataType dtype = DataType::Float(32);
+    Var buffer("b", PointerType(PrimType(dtype)));
+    return Allocate(buffer, dtype, z, const_true(), SeqStmt({body, body, body}));
+  };
+
   auto fmakeif = [&]() {
     auto z = x + 1;
     Stmt body = Evaluate(z);
@@ -225,23 +233,38 @@ TEST(IRF, StmtMutator) {
   };
 
   MyVisitor v;
+
   {
-    auto body = fmakealloc();
-    Stmt body2 = Evaluate(1);
-    Stmt bref = body.as<AllocateNode>()->body;
-    auto* extentptr = body.as<AllocateNode>()->extent.get();
-    Array<Stmt> arr{std::move(body), body2, body2};
-    auto* arrptr = arr.get();
-    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
-    ICHECK(arr.get() == arrptr);
-    // inplace update body
-    ICHECK(arr[0].as<AllocateNode>()->extent.same_as(x));
-    ICHECK(arr[0].as<AllocateNode>()->extent.get() == extentptr);
-    // copy because there is additional refs
-    ICHECK(!arr[0].as<AllocateNode>()->body.same_as(bref));
-    ICHECK(arr[0].as<AllocateNode>()->body.as<EvaluateNode>()->value.same_as(x));
-    ICHECK(bref.as<EvaluateNode>()->value.as<AddNode>());
+    // Inplace update of a CopyOnWrite body if there are no additional references.
+    auto before = fmakealloc_seq_body();
+    const AllocateNode* alloc_ptr = before.as<AllocateNode>();
+    const SeqStmtNode* before_body_ptr = before.as<AllocateNode>()->body.as<SeqStmtNode>();
+    auto after = v(std::move(before));
+
+    // We get the same AllocateNode, and the same SeqStmt inside it.
+    ICHECK_EQ(after.get(), alloc_ptr);
+    auto after_body_ptr = after.as<AllocateNode>()->body.as<SeqStmtNode>();
+    ICHECK_EQ(after_body_ptr, before_body_ptr);
+    // Verify that the change did actually happen.
+    ICHECK(after_body_ptr->seq[0].as<EvaluateNode>()->value.same_as(x));
+  }
+
+  {
+    // Copy a CopyOnWrite body if there are additional references.
+    auto before = fmakealloc_seq_body();
+    auto extra_ref = before.as<AllocateNode>()->body;
+    const AllocateNode* alloc_ptr = before.as<AllocateNode>();
+    const SeqStmtNode* before_body_ptr = before.as<AllocateNode>()->body.as<SeqStmtNode>();
+    auto after = v(std::move(before));
+
+    // We get the same AllocateNode, but a different SeqStmt inside it.
+    ICHECK_EQ(after.get(), alloc_ptr);
+    auto after_body_ptr = after.as<AllocateNode>()->body.as<SeqStmtNode>();
+    ICHECK_NE(after_body_ptr, before_body_ptr);
+    // Verify that the change did actually happen.
+    ICHECK(after_body_ptr->seq[0].as<EvaluateNode>()->value.same_as(x));
   }
+
   {
     Array<Stmt> arr{fmakealloc()};
     // mutate array that is referenced by another one, trigger copy.
@@ -265,7 +288,6 @@ TEST(IRF, StmtMutator) {
     arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
     ICHECK(arr2.get() == arr.get());
   }
-
   {
     auto body = Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1}));
    body = v(std::move(body));
   }
   {
     Stmt body = fmakealloc();
+    auto* ref1 = body.get();
     Stmt body2 = Evaluate(1);
     auto* ref2 = body2.get();
-    auto* extentptr = body.as<AllocateNode>()->extent.get();
     // construct a recursive SeqStmt.
     body = SeqStmt({body});
     body = SeqStmt({body, body2});
     body = SeqStmt({body, body2});
     body = v(std::move(body));
     // the seq get flattened
     ICHECK(body.as<SeqStmtNode>()->size() == 3);
-    ICHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extent.get() == extentptr);
+    ICHECK(body.as<SeqStmtNode>()->seq[0].get() == ref1);
     ICHECK(body.as<SeqStmtNode>()->seq[1].get() == ref2);
   }
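
To close the loop on the series as a whole, here is a minimal Python
sketch of the resulting single-extent API. It is illustrative only and
assumes the updated ir_builder from the earlier patches in this series;
the buffer_size name is arbitrary.

    import tvm

    # Build a trivial PrimFunc whose body is a single allocation.
    ib = tvm.tir.ir_builder.create()
    buffer_size = 16
    B = ib.allocate("float32", buffer_size, name="B", scope="global")
    B[0] = 0.0
    func = tvm.tir.PrimFunc([], ib.get())

    alloc = func.body
    assert isinstance(alloc, tvm.tir.Allocate)
    # After this series, an allocation carries a single flat extent
    # rather than an Array of per-dimension extents.
    assert alloc.extent.value == buffer_size
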