diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 3ddf4af12bea..8b7c0e6c9b93 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1009,7 +1009,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { const BufferLoadNode* load = op->args[0].as(); ICHECK(op->args.size() == 1 && load); ICHECK_EQ(load->indices.size(), 1) << "LLVM only supports flat memory allocations."; - PrimExpr index = load->indices[0]; + PrimExpr index = analyzer_->Simplify(load->buffer->elem_offset + load->indices[0]); if (const RampNode* r = index.as()) { index = r->base; } @@ -1253,15 +1253,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) { auto it = let_binding_.find(op->var); if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, op->value)) + ICHECK(deep_equal_(it->second, op->value)) << "Let cannot bind the same var to two different values"; } else { - let_binding_[op->var] = op; + let_binding_[op->var] = op->value; } auto var_value = MakeValue(op->value); var_map_[op->var.get()] = var_value; var_value->setName(op->var->name_hint.c_str()); - analyzer_->Bind(op->var, op->value); return MakeValue(op->body); } @@ -1301,6 +1300,7 @@ void CodeGenLLVM::BufferAccessHelper( } PrimExpr last_index = indices[indices.size() - 1]; + ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes()); bool is_volatile = volatile_buf_.count(buffer->data.get()); @@ -1355,8 +1355,15 @@ void CodeGenLLVM::BufferAccessHelper( std::vector all_index_values = earlier_index_values; all_index_values.push_back(last_index_value); + llvm::Value* buffer_data_begin = MakeValue(buffer->data); + if (!analyzer_->CanProveEqual(buffer->elem_offset, 0)) { + buffer_data_begin = CreateBufferPtr(buffer_data_begin, buffer_element_dtype, + {MakeValue(buffer->elem_offset)}, buffer_element_dtype) + .addr; + } + TypedPointer buffer_ptr = - CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values, + CreateBufferPtr(buffer_data_begin, buffer_element_dtype, all_index_values, value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes())); auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile); AddAliasInfo(instruction, buffer->data.get(), last_index); @@ -1632,7 +1639,6 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { llvm::Value* value = MakeValue(op->value); value->setName(v->name_hint.c_str()); var_map_[v] = value; - analyzer_->Bind(op->var, op->value); if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) { builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), alloc_storage_info_[v].alignment); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 559ce97f8fc4..052eb4f1a15f 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -439,7 +439,7 @@ class CodeGenLLVM : public ExprFunctor, // deep comparison of PrimExpr ExprDeepEqual deep_equal_; // binding of let variables. Enables duplicate var defs that map to same value - std::unordered_map let_binding_; + std::unordered_map let_binding_; // Cache potential common path ops to slightly improve lookup time. // global symbol table. OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 3ad7882d792c..b764678defc3 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -577,7 +577,8 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) const BufferLoadNode* load = op->args[0].as(); ICHECK(op->args.size() == 1 && load); ICHECK_EQ(load->indices.size(), 1) << "CodeGenC only supports flat memory allocations."; - os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))"; + PrimExpr index = analyzer_.Simplify(load->buffer->elem_offset + load->indices[0]); + os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), index) << "))"; } else if (op->op.same_as(builtin::tvm_struct_get())) { ICHECK_EQ(op->args.size(), 3U); os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); @@ -669,7 +670,7 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; DataType value_dtype = op->dtype; - PrimExpr index = op->indices[0]; + PrimExpr index = analyzer_.Simplify(op->buffer->elem_offset + op->indices[0]); Var buffer_var = op->buffer->data; DataType element_dtype = op->buffer->dtype; @@ -684,7 +685,7 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { const RampNode* ramp = index.as(); ICHECK(ramp); - arith::ModularSet me = arith::Analyzer().modular_set(ramp->base); + arith::ModularSet me = analyzer_.modular_set(ramp->base); // The condition: {k * coeff + base} divisible by the alignment for any k if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes() == 0) { can_vector_load = true; @@ -733,7 +734,7 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) { DataType value_dtype = op->value.dtype(); DataType element_dtype = op->buffer->dtype; - PrimExpr index_expr = op->indices[0]; + PrimExpr index_expr = analyzer_.Simplify(op->buffer->elem_offset + op->indices[0]); Var buffer_var = op->buffer->data; if (value_dtype.lanes() == element_dtype.lanes()) { @@ -786,10 +787,10 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) { void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) auto it = let_binding_.find(op->var); if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, op->value)) + ICHECK(deep_equal_(it->second, op->value)) << "Let cannot bind the same var to two different values"; } else { - let_binding_[op->var] = op; + let_binding_[op->var] = op->value; } std::string value = PrintExpr(op->value); var_idmap_[op->var.get()] = value; diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 4f671950260e..5ade46feaa93 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -24,6 +24,7 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_ #define TVM_TARGET_SOURCE_CODEGEN_C_H_ +#include #include #include #include @@ -280,7 +281,9 @@ class CodeGenC : public ExprFunctor, // deep comparison of PrimExpr ExprDeepEqual deep_equal_; // binding of let variables. Enables duplicate var defs that map to same value - std::unordered_map let_binding_; + std::unordered_map let_binding_; + // analyzer + arith::Analyzer analyzer_; }; } // namespace codegen diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index a0e19ca35cd9..72715c2efd15 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -382,7 +382,9 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { // Overload tvm_address_of to add storage scope (e.g. __global). const BufferLoadNode* load = op->args[0].as(); ICHECK(op->args.size() == 1 && load); - ICHECK_EQ(load->indices.size(), 0) << "CodeGenOpenCL only supports flat memory allocations."; + ICHECK_EQ(load->indices.size(), 1) << "CodeGenOpenCL only supports flat memory allocations."; + PrimExpr index = analyzer_.Simplify(load->buffer->elem_offset + load->indices[0]); + os << "(("; auto it = alloc_storage_scope_.find(load->buffer->data.get()); if (it != alloc_storage_scope_.end()) { @@ -390,7 +392,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { } this->PrintType(load->dtype.element_of(), os); os << " *)" << this->GetVarID(load->buffer->data.get()) << " + "; - this->PrintExpr(load->indices[0], os); + this->PrintExpr(index, os); os << ')'; } else if (op->op.same_as(builtin::texture2d_store())) { auto* ptr_type = op->args[0].as()->type_annotation.as(); diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 3508eef43185..19293a8ba7c4 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -24,6 +24,7 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_ #define TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_ +#include #include #include @@ -85,6 +86,8 @@ class CodeGenOpenCL final : public CodeGenC { // Mapping from buffer to allocation size. // Useful to track when a scalar store of a vectorized texture load is required. std::unordered_map allocation_size_; + // analyzer + arith::Analyzer analyzer_; }; } // namespace codegen diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 4f875e955576..4b8a89edfcc2 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -287,13 +287,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { auto it = let_binding_.find(op->var); if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, op->value)) + ICHECK(deep_equal_(it->second, op->value)) << "Let cannot bind the same var to two different values"; } else { - let_binding_[op->var] = op; + let_binding_[op->var] = op->value; } var_map_[op->var.get()] = MakeValue(op->value); - analyzer_->Bind(op->var, op->value); return MakeValue(op->body); } @@ -428,7 +427,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; Var buffer_var = op->buffer->data; - PrimExpr prim_index = op->indices[0]; + PrimExpr prim_index = analyzer_->Simplify(op->buffer->elem_offset + op->indices[0]); DataType desired_read_type = op->dtype; if (desired_read_type == DataType::Bool()) { @@ -501,7 +500,7 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::functionindices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; Var buffer_var = op->buffer->data; - PrimExpr prim_index = op->indices[0]; + PrimExpr prim_index = analyzer_->Simplify(op->buffer->elem_offset + op->indices[0]); auto it = storage_info_.find(buffer_var.get()); ICHECK(it != storage_info_.end()); @@ -716,7 +715,6 @@ void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) { ICHECK(!var_map_.count(op->var.get())); ICHECK(!op->var.dtype().is_handle()); var_map_[op->var.get()] = MakeValue(op->value); - analyzer_->Bind(op->var, op->value); this->VisitStmt(op->body); } diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 08b9db0ee539..1d0841bbe02c 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -213,7 +213,7 @@ class CodeGenSPIRV : public ExprFunctor, ExprDeepEqual deep_equal_; // binding of let variables. Enables duplicate var defs that map to same value - std::unordered_map let_binding_; + std::unordered_map let_binding_; // Running total of the number of bytes of shared memory used. // Checked against the max_shared_memory_per_group diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index e70405445349..0c5fa45baeb9 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -147,7 +147,7 @@ void CodeGenStackVM::VisitExpr_(const BufferLoadNode* op) { ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. " << "Has StorageFlatten (TE-based schedules) or " << "FlattenBuffer (TIR-based schedules) been run?"; - auto index = op->indices[0]; + auto index = analyzer_.Simplify(op->buffer->elem_offset + op->indices[0]); this->Push(op->buffer->data); StackVM::OpCode code = StackVM::GetLoad(op->dtype); @@ -170,7 +170,7 @@ void CodeGenStackVM::VisitStmt_(const BufferStoreNode* op) { ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. " << "Has StorageFlatten (TE-based schedules) or " << "FlattenBuffer (TIR-based schedules) been run?"; - auto index = op->indices[0]; + auto index = analyzer_.Simplify(op->buffer->elem_offset + op->indices[0]); this->Push(op->buffer->data); StackVM::OpCode code = StackVM::GetStore(op->value.dtype()); @@ -195,10 +195,11 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::address_of())) { const BufferLoadNode* load = op->args[0].as(); ICHECK(op->args.size() == 1 && load); - ICHECK_EQ(load->indices.size(), 0) << "CodeGenStackVM only supports flat memory allocations."; + ICHECK_EQ(load->indices.size(), 1) << "CodeGenStackVM only supports flat memory allocations."; + PrimExpr index = analyzer_.Simplify(load->buffer->elem_offset + load->indices[0]); this->PushOp(StackVM::LOAD_HEAP, GetVarID(load->buffer->data.get())); - this->Push(load->indices[0]); + this->Push(index); this->PushOp(StackVM::PUSH_I64, load->dtype.element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index ae6f316b475d..d3161b15fb08 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -24,6 +24,7 @@ #ifndef TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_ #define TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_ +#include #include #include #include @@ -156,6 +157,8 @@ class CodeGenStackVM : public ExprFunctor, std::unordered_map str_idmap_; /*! \brief id of each global function */ std::unordered_map extern_fun_idmap_; + /*! \brief analyzer */ + arith::Analyzer analyzer_; Op backend_alloc_workspace_op_ = Op::Get("tir.TVMBackendAllocWorkspace"); Op backend_free_workspace_op_ = Op::Get("tir.TVMBackendFreeWorkspace"); diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index c2bf27393173..e7265be957ea 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -167,6 +167,16 @@ class DataTypeVisitor final : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } + void VisitExpr_(const BufferLoadNode* op) { + VisitExpr(op->buffer->elem_offset); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode* op) { + VisitExpr(op->buffer->elem_offset); + StmtExprVisitor::VisitStmt_(op); + } + // the narrowed datatype of Var and IntImm std::unordered_map vmap; @@ -217,11 +227,13 @@ class DataTypeRewriter : public StmtExprMutator { Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = GetRef(op); + auto buffer = VisitBuffer(op->buffer); auto value = this->VisitExpr(op->value); auto indices = VisitIndices(op->indices); - if (!value.same_as(op->value) || !indices.same_as(op->indices)) { + if (!buffer.same_as(op->buffer) || !value.same_as(op->value) || !indices.same_as(op->indices)) { auto writer = store.CopyOnWrite(); + writer->buffer = buffer; writer->value = value; writer->indices = indices; } @@ -232,10 +244,12 @@ class DataTypeRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { BufferLoad load = GetRef(op); + auto buffer = VisitBuffer(op->buffer); auto indices = VisitIndices(op->indices); - if (!indices.same_as(op->indices)) { + if (!buffer.same_as(op->buffer) || !indices.same_as(op->indices)) { auto writer = load.CopyOnWrite(); + writer->buffer = buffer; writer->indices = indices; } @@ -253,6 +267,23 @@ class DataTypeRewriter : public StmtExprMutator { return indices; } + Buffer VisitBuffer(const Buffer& origin_buffer) { + auto it = buffer_remap_.find(origin_buffer); + if (it != buffer_remap_.end()) { + return it->second; + } + is_index_ = true; + PrimExpr elem_offset = VisitExpr(origin_buffer->elem_offset); + is_index_ = false; + Buffer updated_buffer = origin_buffer; + if (!elem_offset.same_as(origin_buffer->elem_offset)) { + auto n = updated_buffer.CopyOnWrite(); + n->elem_offset = elem_offset; + } + buffer_remap_.insert(it, {origin_buffer, updated_buffer}); + return updated_buffer; + } + Stmt VisitStmt_(const ForNode* op) final { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); @@ -375,6 +406,9 @@ class DataTypeRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const GENode* op) final; PrimExpr VisitExpr_(const CallNode* op) final; + // buffer remapping dict + std::unordered_map buffer_remap_; + private: // the internal visitor to deduce the narrowed dtype DataTypeVisitor visitor_; @@ -449,7 +483,17 @@ namespace transform { Pass NarrowDataType(int target_bits) { auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - n->body = DataTypeRewriter(target_bits)(std::move(n->body)); + DataTypeRewriter rewriter(target_bits); + n->body = rewriter(std::move(n->body)); + for (const Var& param : f->params) { + auto it = n->buffer_map.find(param); + if (it != n->buffer_map.end()) { + auto remap_it = rewriter.buffer_remap_.find((*it).second); + if (remap_it != rewriter.buffer_remap_.end()) { + n->buffer_map.Set(param, remap_it->second); + } + } + } return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 0534f31c3423..a10deafcd58f 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -406,37 +406,13 @@ class StoragePlanRewriter : public StmtExprMutator { Node VisitBufferAccess(Node node) { auto it = alloc_map_.find(node->buffer->data.get()); if (it != alloc_map_.end()) { - Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var); - - Array indices = node->indices; - indices.Set(indices.size() - 1, - RemapIndex(node->buffer->dtype, indices[indices.size() - 1], it->second)); - + Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var, it->second); auto writer = node.CopyOnWrite(); writer->buffer = buf; - writer->indices = indices; } return node; } - Buffer RemapBuffer(Buffer buf, Var new_backing_array) { - auto key = buf.get(); - auto it = buffer_remap_.find(key); - if (it != buffer_remap_.end()) { - ICHECK_EQ(it->second->data.get(), new_backing_array.get()) - << "Cannot remap buffer " << buf->name << " to use backing array " - << new_backing_array->name_hint << ", previously used backing array " - << it->second->data->name_hint; - return it->second; - } - - Buffer remapped = Buffer(new_backing_array, buf->dtype, buf->shape, buf->strides, - buf->elem_offset, new_backing_array->name_hint, buf->data_alignment, - buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span); - buffer_remap_[key] = remapped; - return remapped; - } - Stmt VisitStmt_(const BufferStoreNode* op) final { auto node = Downcast(StmtExprMutator::VisitStmt_(op)); return VisitBufferAccess(std::move(node)); @@ -577,12 +553,30 @@ class StoragePlanRewriter : public StmtExprMutator { } return MergeNest(nest, body); } - // Remap the index - PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) { - if (e->bits_offset == 0) return index; - uint64_t elem_bits = dtype.bits(); - ICHECK_EQ(e->bits_offset % elem_bits, 0U); - return make_const(index.dtype(), e->bits_offset / elem_bits) + index; + // Remap the buffer + Buffer RemapBuffer(Buffer buf, Var new_backing_array, const StorageEntry* e) { + auto key = buf.get(); + auto it = buffer_remap_.find(key); + if (it != buffer_remap_.end()) { + ICHECK_EQ(it->second->data.get(), new_backing_array.get()) + << "Cannot remap buffer " << buf->name << " to use backing array " + << new_backing_array->name_hint << ", previously used backing array " + << it->second->data->name_hint; + return it->second; + } + + PrimExpr elem_offset = buf->elem_offset; + if (e->bits_offset != 0) { + uint64_t elem_bits = buf->dtype.bits(); + ICHECK_EQ(e->bits_offset % elem_bits, 0U); + elem_offset += make_const(elem_offset.dtype(), e->bits_offset / elem_bits); + } + + Buffer remapped = Buffer(new_backing_array, buf->dtype, buf->shape, buf->strides, elem_offset, + new_backing_array->name_hint, buf->data_alignment, buf->offset_factor, + buf->buffer_type, buf->axis_separators, buf->span); + buffer_remap_[key] = remapped; + return remapped; } // Prepare the new allocations void PrepareNewAlloc() { diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index b73534090ab5..30c92d361823 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -200,7 +200,7 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda int pool_size = all_pools_sizes_[pool_info]; String buffer_var_name = pool_ref_name + "_buffer_var"; - si.buffer_map.Set(pool_var, Buffer(buffer_var, elem_dtype, {pool_size}, {1}, 1, buffer_var_name, + si.buffer_map.Set(pool_var, Buffer(buffer_var, elem_dtype, {pool_size}, {1}, 0, buffer_var_name, 16, 1, BufferType::kDefault)); } if (resource_handle) { diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index 3250efc3f71e..d0d4a6341980 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -14,11 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import tempfile import tvm from tvm import te from tvm.contrib import nvcc import numpy as np -import time +from tvm.script import tir as T import tvm.testing @@ -301,6 +302,51 @@ def check_device(device): check_device("cuda") +@tvm.testing.requires_llvm +def test_aliased_buffer_on_different_offsets(): + @T.prim_func + def func_with_aliased_buffers(X: T.Buffer[(128,), "int32"], Y: T.Buffer[(2048,), "int8"]): + T.func_attr({"global_symbol": "func_with_aliased_buffers"}) + Y1 = T.buffer_decl([128], "int32", Y.data, elem_offset=0) + Y2 = T.buffer_decl([128], "int32", Y.data, elem_offset=128) + Y3 = T.buffer_decl([128], "int32", Y.data, elem_offset=256) + Y4 = T.buffer_decl([128], "int32") + with T.let(Y4.data, T.address_of(Y3[128], dtype="handle")): + for i in range(128): + Y1[i] = X[i] + 1 + for i in range(128): + Y2[i] = X[i] - 1 + for i in range(128): + Y3[i] = X[i] * 2 + for i in range(128): + Y4[i] = X[i] * (-2) + + def check_target(target, dev): + if target != "c" and not tvm.testing.device_enabled(target): + print("skip because %s is not enabled.." % target) + return + f = tvm.build({target: tvm.IRModule.from_expr(func_with_aliased_buffers)}, target=target) + x = tvm.nd.array((np.random.uniform(-256, 256, 128)).astype("int32"), dev) + y = tvm.nd.array(np.zeros([2048], dtype="int8"), dev) + if target == "c": + with tempfile.NamedTemporaryFile(suffix="temp.so") as libfile: + f.export_library(libfile.name) + f = tvm.runtime.load_module(libfile.name) + f["func_with_aliased_buffers"](x, y) + y1 = y.numpy()[:512].view("int32") + y2 = y.numpy()[512:1024].view("int32") + y3 = y.numpy()[1024:1536].view("int32") + y4 = y.numpy()[1536:].view("int32") + tvm.testing.assert_allclose(x.numpy() + 1, y1, rtol=1e-6) + tvm.testing.assert_allclose(x.numpy() - 1, y2, rtol=1e-6) + tvm.testing.assert_allclose(x.numpy() * 2, y3, rtol=1e-6) + tvm.testing.assert_allclose(x.numpy() * -2, y4, rtol=1e-6) + + check_target("llvm", tvm.cpu()) + check_target("c", tvm.cpu()) + check_target("stackvm", tvm.cpu()) + + if __name__ == "__main__": test_exp() try_warp_memory() @@ -309,3 +355,4 @@ def check_device(device): test_log_pow_llvm() test_popcount() test_fmod() + test_aliased_buffer_on_different_offsets() diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index 51c382309856..ecd2857bf320 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -18,6 +18,7 @@ from tvm import te, relay from tvm.driver.build_module import schedule_to_module from tvm.tir import const +from tvm.script import tir as T def lower_stmt(params, stmt, target_bits): @@ -277,6 +278,25 @@ def test_ramp_dtype_consistency(): lower_sch(s, [A], 32, extra_passes=[tvm.tir.transform.VectorizeLoop()]) +def test_narrow_buffer_elemoffset(): + @T.prim_func + def func_before(): + X = T.buffer_decl([128], elem_offset=T.int64(1000)) + Y = T.buffer_decl([128], elem_offset=T.int64(100)) + for i in range(128): + Y[i] = X[i] + 1.0 + + @T.prim_func + def func_after(): + X = T.buffer_decl([128], elem_offset=T.int32(1000)) + Y = T.buffer_decl([128], elem_offset=T.int32(100)) + for i in range(128): + Y[i] = X[i] + 1.0 + + mod = tvm.tir.transform.NarrowDataType(32)(tvm.IRModule.from_expr(func_before)) + tvm.ir.assert_structural_equal(mod["main"], func_after) + + if __name__ == "__main__": test_basic() test_thread_axis() @@ -286,3 +306,4 @@ def test_ramp_dtype_consistency(): test_relay_basic() test_relay_take() test_ramp_dtype_consistency() + test_narrow_buffer_elemoffset() diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 5a91788283d6..2131205e62eb 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -17,6 +17,7 @@ import tvm from tvm import te from tvm.driver.build_module import schedule_to_module +from tvm.script import tir as T def test_storage_share(): @@ -185,35 +186,44 @@ def verify(n): def test_storage_combine(): - n = 8 - A = te.placeholder((4,), name="A") - num_stage = 5 - B = A - stages = [] - for t in range(num_stage): - B = te.compute((n,), lambda i: B[i] + B[0] + (t + 1), name="A%d" % t) - stages.append(B) - - s = te.create_schedule(B.op) - for S in stages[:-1]: - s[S].set_scope("global:tag") - - mod = schedule_to_module(s, [A, B]) - mod = tvm.tir.transform.StorageFlatten(64)(mod) - - mod = tvm.tir.transform.Simplify()(mod) - mod = tvm.tir.transform.StorageRewrite()(mod) - stmt = mod["main"].body - - num_alloc = [0] - - def verify(n): - if isinstance(n, tvm.tir.Allocate): - num_alloc[0] += 1 - assert n.extents[0].value == 16 - - tvm.tir.stmt_functor.post_order_visit(stmt, verify) - assert num_alloc[0] == 1 + @T.prim_func + def before_rewrite(A: T.Buffer[(4,), "float32"], A4: T.Buffer[(8,), "float32"]) -> None: + A0 = T.allocate([8], "float32", "global:tag") + for i in T.serial(8): + A0[i] = A[i] + A[0] + T.float32(1) + A1 = T.allocate([8], "float32", "global:tag") + for i in T.serial(8): + A1[i] = A0[i] + A0[0] + T.float32(2) + A2 = T.allocate([8], "float32", "global:tag") + for i in T.serial(8): + A2[i] = A1[i] + A1[0] + T.float32(3) + A3 = T.allocate([8], "float32", "global:tag") + for i in T.serial(8): + A3[i] = A2[i] + A2[0] + T.float32(4) + for i in T.serial(8): + A4[i] = A3[i] + A3[0] + T.float32(5) + + @T.prim_func + def after_rewrite(A: T.Buffer[(4,), "float32"], A4: T.Buffer[(8,), "float32"]) -> None: + A0 = T.allocate([16], "float32", "global:tag") + A0_1 = T.buffer_decl([8], dtype="float32", data=A0.data, scope="global:tag") + A0_2 = T.buffer_decl([8], dtype="float32", data=A0.data, elem_offset=8, scope="global:tag") + A0_3 = T.buffer_decl([8], dtype="float32", data=A0.data, scope="global:tag") + A0_4 = T.buffer_decl([8], dtype="float32", data=A0.data, elem_offset=8, scope="global:tag") + for i in T.serial(8): + A0_1[i] = A[i] + A[0] + T.float32(1) + for i in T.serial(8): + A0_2[i] = A0_1[i] + A0_1[0] + T.float32(2) + for i in T.serial(8): + A0_3[i] = A0_2[i] + A0_2[0] + T.float32(3) + for i in T.serial(8): + A0_4[i] = A0_3[i] + A0_3[0] + T.float32(4) + for i in T.serial(8): + A4[i] = A0_4[i] + A0_4[0] + T.float32(5) + + mod = tvm.tir.transform.StorageRewrite()(tvm.IRModule.from_expr(before_rewrite)) + mod = tvm.tir.transform.RemoveNoOp()(mod) + tvm.ir.assert_structural_equal(mod["main"], after_rewrite) def test_storage_combine_with_vectorization(): diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 4ed02615cd44..99fff94d9c73 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -141,8 +141,8 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: class LinearStructurePlanned: @T.prim_func def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_var: T.Ptr[T.uint8], output: T.handle) -> None: - fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) @@ -156,8 +156,8 @@ def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") - fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body tensor_2_let = T.buffer_decl([200704], dtype="uint8") with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")): @@ -174,8 +174,8 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") - fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @@ -186,8 +186,8 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8") - fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_7_let = T.buffer_decl([157323], "int16") with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): @@ -371,7 +371,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") - global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -383,7 +383,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") - global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_3_let = T.buffer_decl([360000], 'int16') with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")): @@ -406,7 +406,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") - global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_2_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")): @@ -429,7 +429,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") - global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")): @@ -451,7 +451,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") - global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_1_let = T.buffer_decl([379456], "int16") with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): @@ -469,7 +469,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla @T.prim_func def __tvm_main__(input: T.handle, global_workspace_0_var: T.Ptr[T.uint8], output: T.handle) -> None: - global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 1e8247c6e135..b6be993a4d16 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -948,7 +948,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents): nest_size += 1 # Get the src/dst arguments dst_var = loop_body.buffer.data - dst_idx = loop_body.indices[0] + dst_idx = loop_body.buffer.elem_offset + loop_body.indices[0] # Derive loop variables and extents tmp_body = stmt.body indices = [] @@ -1008,19 +1008,23 @@ def _flatten_loop(src_coeff, dst_coeff, extents): imm_val = None if isinstance(rhs, tvm.tir.IntImm): assert lhs.buffer.data.same_as(dst_var) - src_coeff = tvm.arith.detect_linear_equation(lhs.indices[0], indices) + lhs_index = lhs.buffer.elem_offset + lhs.indices[0] + src_coeff = tvm.arith.detect_linear_equation(lhs_index, indices) use_imm = True imm_val = rhs if isinstance(lhs, tvm.tir.IntImm): assert rhs.buffer.data.same_as(dst_var) - src_coeff = tvm.arith.detect_linear_equation(rhs.indices[0], indices) + rhs_index = rhs.buffer.elem_offset + rhs.indices[0] + src_coeff = tvm.arith.detect_linear_equation(rhs_index, indices) use_imm = True imm_val = lhs if imm_val is None: imm_val = 0 + lhs_index = lhs.buffer.elem_offset + lhs.indices[0] + rhs_index = rhs.buffer.elem_offset + rhs.indices[0] assert lhs.buffer.data.same_as(dst_var) and rhs.buffer.data.same_as(dst_var) - src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.indices[0], indices) - src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.indices[0], indices) + src_lhs_coeff = tvm.arith.detect_linear_equation(lhs_index, indices) + src_rhs_coeff = tvm.arith.detect_linear_equation(rhs_index, indices) # Determine which side has the same coefficients lhs_equal = True rhs_equal = True