From cddc1dee8a351e9cb00401fa14a13f33468d3ddb Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 20 Aug 2020 12:13:28 -0700 Subject: [PATCH] [TIR] Enforce buffer pointer var type to be consistent with dtype. Now that we have type_annotation in tir::Var, we should make sure that the type annotation is consistent with the dtype in Buffer declaration and Allocation. This change allows future passes to directly use the content type information via type_annotation. This PR turns on the enforcement on Buffer and also fixes a few cases for Allocate. A follow-up PR needs to fix a few more cases in the hybrid script parsing before everything can be made consistent. --- include/tvm/tir/op.h | 17 +++ python/tvm/tir/buffer.py | 4 +- python/tvm/tir/ir_builder.py | 4 +- src/driver/driver_api.cc | 2 +- src/tir/ir/buffer.cc | 5 + src/tir/ir/stmt.cc | 3 + src/tir/transforms/bf16_legalize.cc | 157 ++++++++++++++------------ src/tir/transforms/storage_flatten.cc | 6 +- 8 files changed, 115 insertions(+), 83 deletions(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 68ca2663ede9..93a54b044fba 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -617,6 +617,23 @@ TVM_DECLARE_INTRIN_BINARY(hypot); TVM_DECLARE_INTRIN_BINARY(ldexp); namespace tir { + +/*! + * \brief Check if type is a pointer to a runtime element type. + * \param type The type to be checked. + * \param element_type The corresponding element type. + * \return The check results + */ +inline bool IsPointerType(const Type& type, const DataType& element_type) { + if (!type.defined()) return false; + if (const auto* ptr_type = type.as()) { + if (const auto* prim_type = ptr_type->element_type.as()) { + return prim_type->dtype == element_type; + } + } + return false; +} + /*! * \brief Make a const value with certain data type. * \param t The target type. 
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 11bfb4c55921..bd7672a52d9a 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -20,7 +20,7 @@ from tvm._ffi.base import string_types from tvm.runtime import Object, convert -from tvm.ir import PrimExpr +from tvm.ir import PrimExpr, PointerType, PrimType from . import _ffi_api @@ -241,7 +241,7 @@ def decl_buffer(shape, shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" elem_offset = Var('%s_elem_offset' % name, shape_dtype) if data is None: - data = Var(name, "handle") + data = Var(name, PointerType(PrimType(dtype))) return _ffi_api.Buffer( data, dtype, shape, strides, elem_offset, name, scope, data_alignment, offset_factor, buffer_type) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 20180d1be45d..b313e58a03af 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -17,7 +17,7 @@ """Developer API of IR node builder make function.""" from tvm._ffi.base import string_types from tvm.runtime import ObjectGeneric, DataType, convert, const -from tvm.ir import container as _container +from tvm.ir import container as _container, PointerType, PrimType from . import stmt as _stmt from . import expr as _expr @@ -325,7 +325,7 @@ def allocate(self, dtype, shape, name="buf", scope=None): buffer : BufferVar The buffer var representing the buffer. 
""" - buffer_var = _expr.Var(name, dtype="handle") + buffer_var = _expr.Var(name, PointerType(PrimType(dtype))) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] if scope: diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 142bdfc70dce..14aa4fc56e2e 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -69,7 +69,7 @@ Target DefaultTargetHost(Target target) { tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact) { - auto data = tir::Var(name, DataType::Handle()); + auto data = tir::Var(name, PointerType(PrimType(dtype))); bool has_any = false; if (!compact) { for (const auto& it : shape) { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 00e3335633ec..d33f2ddf698a 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -383,9 +383,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, PrimExpr elem_offset, String name, String scope, int data_alignment, int offset_factor, BufferType buffer_type) { + CHECK(IsPointerType(data->type_annotation, dtype)) + << "Buffer data field expect to have the right pointer type annotation" + << " annotation=" << data->type_annotation << ", dtype=" << dtype; + auto n = make_object(); n->data = std::move(data); n->dtype = dtype; + n->shape = std::move(shape); n->strides = std::move(strides); n->name = std::move(name); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 296f49207cce..d9e1df46e8fa 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -263,6 +263,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, Stmt body) { + // TODO(tvm-team): Add invariant check to make sure + // IsPointerType(buffer_var->type_annotation, dtype) + // once we fix the allocate 
hybrid script printing. for (size_t i = 0; i < extents.size(); ++i) { CHECK(extents[i].defined()); CHECK(extents[i].dtype().is_scalar()); diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 4a44b85684b2..97c96edc6ca7 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -172,14 +172,11 @@ uint16_t RoundToNearestEven(float src) { * Lower cast between bf16 and fp32 * Lower bf16 FloatImm to int16 */ -class BF16LowerRewriter : StmtExprMutator { +class BF16LowerRewriter : public StmtExprMutator { public: BF16LowerRewriter() {} - std::unordered_map buffer_remap; - std::unordered_map var_remap; - - Stmt operator()(Stmt s) { return VisitStmt(s); } + using StmtExprMutator::operator(); PrimExpr VisitExpr_(const CastNode* op) final { auto op_val = StmtExprMutator::VisitExpr(op->value); @@ -190,7 +187,6 @@ class BF16LowerRewriter : StmtExprMutator { auto uint32_v = Cast(uint32_dtype, op_val); // to be endian invariant. 
return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}); - } else if (op->dtype.is_bfloat16()) { // if is cast_to_bf16, check if op->value is fp32 CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); @@ -209,104 +205,104 @@ class BF16LowerRewriter : StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) final { - auto itr = var_remap.find(op); - if (itr != var_remap.end()) { + Var var = GetRef(op); + + auto itr = var_remap_.find(var); + if (itr != var_remap_.end()) { return itr->second; + } else { + return std::move(var); } - if (op->dtype.is_bfloat16()) { - CHECK(!op->type_annotation.defined()); - auto ret = Var(op->name_hint, op->dtype); - var_remap[op] = ret; - return std::move(ret); - } - return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const AllocateNode* op) final { - Stmt node_holder; - const AllocateNode* newop; if (op->dtype.is_bfloat16()) { - auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents, - op->condition, op->body); - node_holder = v; - newop = static_cast(v.operator->()); + DataType dtype = DataType::UInt(16, op->dtype.lanes()); + Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype))); + var_remap_[op->buffer_var] = buffer_var; + return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body)); } else { - newop = op; + return StmtExprMutator::VisitStmt_(op); } - return StmtExprMutator::VisitStmt_(newop); } Stmt VisitStmt_(const BufferStoreNode* op) final { - auto itr = buffer_remap.find(op->buffer.operator->()); - const BufferStoreNode* newop; - BufferStore newop_holder; - if (itr != buffer_remap.end()) { - newop_holder = BufferStore(itr->second, op->value, op->indices); - newop = newop_holder.operator->(); + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + auto it = buffer_remap_.find(op->buffer); + if (it != buffer_remap_.end()) { + return BufferStore(it->second, op->value, op->indices); } else { - newop = op; + return 
ret; } - return StmtExprMutator::VisitStmt_(newop); } Stmt VisitStmt_(const AttrStmtNode* op) final { - const AttrStmtNode* newop = op; - Stmt newop_holder; - if (auto buffer = op->node.as()) { - auto itr = buffer_remap.find(buffer); - if (itr != buffer_remap.end()) { - newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body); - newop = newop_holder.as(); + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + if (auto* buffer = op->node.as()) { + auto it = buffer_remap_.find(GetRef(buffer)); + if (it != buffer_remap_.end()) { + return AttrStmt(it->second, op->attr_key, op->value, op->body); } - } else if (auto buffer = op->node.as()) { - auto itr = var_remap.find(buffer); - if (itr != var_remap.end()) { - newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body); - newop = newop_holder.as(); + } else if (auto* var = op->node.as()) { + auto it = var_remap_.find(GetRef(var)); + if (it != var_remap_.end()) { + return AttrStmt(it->second, op->attr_key, op->value, op->body); } } - return StmtExprMutator::VisitStmt_(newop); + return ret; } Stmt VisitStmt_(const BufferRealizeNode* op) final { - auto itr = buffer_remap.find(op->buffer.operator->()); - const BufferRealizeNode* newop; - Stmt newop_holder; - if (itr != buffer_remap.end()) { - auto v = BufferRealize(itr->second, op->bounds, op->condition, op->body); - newop_holder = v; - newop = v.operator->(); + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + auto it = buffer_remap_.find(op->buffer); + if (it != buffer_remap_.end()) { + return BufferRealize(it->second, op->bounds, op->condition, op->body); } else { - newop = op; + return ret; + } + } + + Stmt VisitStmt_(const StoreNode* op) final { + // NOTE: we do not explicitly recursively mutate op->buffer_var + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + auto it = var_remap_.find(op->buffer_var); + if (it != var_remap_.end()) { + return Store(it->second, op->value, op->index, op->predicate); + } 
else { + return ret; } - return StmtExprMutator::VisitStmt_(newop); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { - auto itr = buffer_remap.find(op->buffer.operator->()); - const BufferLoadNode* newop; - BufferLoad newop_holder; - if (itr != buffer_remap.end()) { - newop_holder = BufferLoad(itr->second, op->indices); - newop = newop_holder.operator->(); + PrimExpr ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + + auto it = buffer_remap_.find(op->buffer); + if (it != buffer_remap_.end()) { + return BufferLoad(it->second, op->indices); } else { - newop = op; + return ret; } - return StmtExprMutator::VisitExpr_(newop); } PrimExpr VisitExpr_(const LoadNode* op) final { - bool is_bf16 = false; + PrimExpr ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + if (op->dtype.is_bfloat16()) { - is_bf16 = true; - } - PrimExpr index = this->VisitExpr(op->index); - PrimExpr predicate = this->VisitExpr(op->predicate); - if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) { - return GetRef(op); + auto it = var_remap_.find(op->buffer_var); + CHECK(it != var_remap_.end()) << "bfloat* var needs to be remapped"; + return Load(DataType::UInt(16, op->dtype.lanes()), it->second, op->index, op->predicate); } else { - return Load(is_bf16 ? 
DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var, - index, predicate); + return ret; } } @@ -320,20 +316,31 @@ class BF16LowerRewriter : StmtExprMutator { void AlterBuffers(PrimFuncNode* op) { std::vector> changes; + for (auto& itr : op->buffer_map) { auto oldbuf = itr.second; if (oldbuf->dtype.is_bfloat16()) { - auto newbuf = Buffer(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape, - oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope, - oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type); - buffer_remap[oldbuf.operator->()] = newbuf; + DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); + Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype))); + auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset, + oldbuf->name, oldbuf->scope, oldbuf->data_alignment, + oldbuf->offset_factor, oldbuf->buffer_type); + buffer_remap_[oldbuf] = newbuf; + var_remap_[oldbuf->data] = buffer_var; changes.emplace_back(itr.first, newbuf); + } else { + changes.emplace_back(itr); } } - if (buffer_remap.size() != 0) { + + if (buffer_remap_.size() != 0) { op->buffer_map = Map(changes.begin(), changes.end()); } } + + private: + std::unordered_map buffer_remap_; + std::unordered_map var_remap_; }; namespace transform { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 8eb43f8ebc84..7475bf6d2f8e 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -200,9 +200,9 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } - e.buffer = - Buffer(Var(op->buffer->data->name_hint, DataType::Handle()), op->buffer->dtype, shape, - strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault); + e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation), + op->buffer->dtype, shape, strides, 
PrimExpr(), op->buffer->name, + skey.to_string(), align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body);