From 7a4e763c314cdb830cf3c0534790d3f4fd4c1f02 Mon Sep 17 00:00:00 2001
From: Andrey Malyshev
Date: Fri, 24 Jun 2022 12:32:52 +0400
Subject: [PATCH 1/2] [Relay] Handle memory scope during lowering from Relay level

Relay expressions can have virtual devices with a certain memory scope
assigned. This change lands the memory scope information from the Relay
level in TIR.
---
 include/tvm/driver/driver_api.h        | 17 +++++++++++++++
 include/tvm/relay/expr.h               |  3 ++-
 src/driver/driver_api.cc               |  5 +++--
 src/relay/backend/te_compiler.cc       | 29 +++++++++++++++++++++++++-
 src/relay/backend/te_compiler_cache.cc |  3 ++-
 src/relay/backend/te_compiler_cache.h  |  7 ++++++-
 6 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h
index 45a938247cc8..1428a3f705f8 100644
--- a/include/tvm/driver/driver_api.h
+++ b/include/tvm/driver/driver_api.h
@@ -166,6 +166,23 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target&
  */
 TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host);
 
+/*!
+ * \brief Creates a TIR Buffer for the provided parameters.
+ * \param shape The shape of the buffer.
+ * \param dtype The data type.
+ * \param name The buffer name.
+ * \param data_alignment The alignment requirement of the data pointer in bytes.
+ * \param offset_factor The factor of the elem_offset field; elem_offset is guaranteed to be
+ *        a multiple of offset_factor.
+ *        User can specify data_alignment and offset_factor to be 0;
+ *        a default value will be picked.
+ * \param compact Whether the statement has already been bound to a compact buffer.
+ * \param memory_scope The memory scope of the buffer.
+ */
+TVM_DLL tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype,
+                                              std::string name, int data_alignment,
+                                              int offset_factor, bool compact,
+                                              std::string memory_scope = "");
 }  // namespace tvm
 
 #endif  // TVM_DRIVER_DRIVER_API_H_
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 6b014c8478d8..bd094a7f6905 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -218,7 +218,8 @@ class VarNode : public ExprNode {
 
   bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
     equal->MarkGraphNode();
-    return equal(type_annotation, other->type_annotation) && equal(vid, other->vid);
+    return equal(type_annotation, other->type_annotation) && equal(vid, other->vid) &&
+           equal(virtual_device_, other->virtual_device_);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 7f015e7ca2b9..af10d9c7252d 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -84,9 +84,10 @@ Target DefaultTargetHost(Target target) {
 }
 
 tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
-                                      int data_alignment, int offset_factor, bool compact) {
+                                      int data_alignment, int offset_factor, bool compact,
+                                      std::string memory_scope) {
   DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
-  auto data = tir::Var(name, PointerType(PrimType(storage_dtype)));
+  auto data = tir::Var(name, PointerType(PrimType(storage_dtype), memory_scope));
   bool has_any = false;
   if (!compact) {
     for (const auto& it : shape) {
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index e9491b0a8901..d1682909d9bf 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -414,6 +414,32 @@ class TECompilerImpl : public TECompilerNode {
     }
     // lower the function
     std::unordered_map<te::Tensor, tir::Buffer> binds;
+
+    // If we have memory scopes, we need to create the tir::Buffers with this info
+    size_t i = 0;  // index of the corresponding tensor in the flattened inputs
+    for (Var param : key->source_func->params) {
+      if (!param->virtual_device()->memory_scope.empty()) {
+        for (const auto& ttype : FlattenTupleType(param->checked_type())) {
+          te::Tensor x_ref = value->cached_func->inputs[i];
+          // verify that the parameters and the prepared tensors are in sync
+          ICHECK(ttype->dtype == x_ref->dtype && ttype->shape.size() == x_ref->shape.size())
+              << "function parameter does not correspond to prepared tensor";
+          binds[x_ref] =
+              BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, false,
+                                        param->virtual_device()->memory_scope);
+        }
+      }
+      i++;
+    }
+    if (key->virtual_device != VirtualDevice::FullyUnconstrained() &&
+        !key->virtual_device->memory_scope.empty() &&
+        key->virtual_device->memory_scope != "global") {
+      ICHECK(value->cached_func->outputs.size() == 1)
+          << "Expect only one output for a defined memory scope";
+      te::Tensor x_ref = value->cached_func->outputs[0];
+      binds[x_ref] = BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0,
+                                               false, key->virtual_device->memory_scope);
+    }
     auto func_name = value->cached_func->prim_fn_var->name_hint;
     VLOG(1) << "scheduling";
     IRModule scheduled_module =
@@ -895,7 +921,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
     } else {
       // Cases 1 and 2: lower the primitive function for the desired target, possibly using external
       // codegen.
-      CCacheKey key(Downcast<Function>(primitive_func), target);
+      CCacheKey key(Downcast<Function>(primitive_func), target,
+                    GetVirtualDevice(GetRef<Call>(call_node)));
       CachedFunc cfunc = compiler_->Lower(key, module_name_);
       ICHECK(cfunc.defined());
       return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, std::move(new_args),
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 8715900c0c4a..17e3d573053f 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -66,10 +66,11 @@ LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation im
   data_ = std::move(n);
 }
 
-CCacheKey::CCacheKey(Function source_func, Target target) {
+CCacheKey::CCacheKey(Function source_func, Target target, VirtualDevice vd) {
   auto n = make_object<CCacheKeyNode>();
   n->source_func = std::move(source_func);
   n->target = std::move(target);
+  n->virtual_device = std::move(vd);
   data_ = std::move(n);
 }
 
diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h
index 55f221ac8ba0..ac2619826019 100644
--- a/src/relay/backend/te_compiler_cache.h
+++ b/src/relay/backend/te_compiler_cache.h
@@ -82,10 +82,13 @@ class CCacheKeyNode : public Object {
   Function source_func;
   /*! \brief The hardware target.*/
   Target target;
+  /*! \brief The virtual device constraints.*/
+  VirtualDevice virtual_device;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("source_func", &source_func);
     v->Visit("target", &target);
+    v->Visit("virtual_device", &virtual_device);
   }
   /*! \return The hash value of CCacheKey. */
   inline size_t Hash() const;
@@ -117,7 +120,8 @@ class CCacheKey : public ObjectRef {
   * \param source_func The source function.
   * \param target The target device.
   */
-  TVM_DLL CCacheKey(Function source_func, Target target);
+  TVM_DLL CCacheKey(Function source_func, Target target,
+                    VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained());
 
   const CCacheKeyNode* operator->() const { return static_cast<const CCacheKeyNode*>(get()); }
   // comparator
@@ -244,6 +248,7 @@ inline size_t CCacheKeyNode::Hash() const {
 inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const {
   if (Hash() != other->Hash()) return false;
   return this->target->str() == other->target->str() &&
+         this->virtual_device == other->virtual_device &&
          tvm::StructuralEqual()(this->source_func, other->source_func);
 }
 

From a46f8c40c08f8b2506d5f4d85477a7a434851ecc Mon Sep 17 00:00:00 2001
From: Andrey Malyshev
Date: Wed, 29 Jun 2022 12:10:25 +0400
Subject: [PATCH 2/2] Move BufferWithOffsetAlignment from driver_api to tir
 buffer

---
 include/tvm/driver/driver_api.h  | 18 ------------------
 include/tvm/tir/buffer.h         | 17 +++++++++++++++++
 src/driver/driver_api.cc         | 31 ++-----------------------------
 src/relay/backend/te_compiler.cc |  9 +++++----
 src/tir/ir/buffer.cc             | 27 +++++++++++++++++++++++++++
 5 files changed, 51 insertions(+), 51 deletions(-)

diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h
index 1428a3f705f8..48800b193cb4 100644
--- a/include/tvm/driver/driver_api.h
+++ b/include/tvm/driver/driver_api.h
@@ -165,24 +165,6 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target&
  * \return The built module that contains code for different processors.
  */
 TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host);
-
-/*!
- * \brief Creates a TIR Buffer for the provided parameters.
- * \param shape The shape of the buffer.
- * \param dtype The data type.
- * \param name The buffer name.
- * \param data_alignment The alignment requirement of the data pointer in bytes.
- * \param offset_factor The factor of the elem_offset field; elem_offset is guaranteed to be
- *        a multiple of offset_factor.
- *        User can specify data_alignment and offset_factor to be 0;
- *        a default value will be picked.
- * \param compact Whether the statement has already been bound to a compact buffer.
- * \param memory_scope The memory scope of the buffer.
- */
-TVM_DLL tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype,
-                                              std::string name, int data_alignment,
-                                              int offset_factor, bool compact,
-                                              std::string memory_scope = "");
 }  // namespace tvm
 
 #endif  // TVM_DRIVER_DRIVER_API_H_
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index ca7faf1cdefb..d7a2aec0b972 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -295,6 +295,23 @@ class DataProducer : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode);
 };
 
+/*!
+ * \brief Creates a TIR Buffer for the provided parameters.
+ * \param shape The shape of the buffer.
+ * \param dtype The data type.
+ * \param name The buffer name.
+ * \param data_alignment The alignment requirement of the data pointer in bytes.
+ * \param offset_factor The factor of the elem_offset field; elem_offset is guaranteed to be
+ *        a multiple of offset_factor.
+ *        User can specify data_alignment and offset_factor to be 0;
+ *        a default value will be picked.
+ * \param compact Whether the statement has already been bound to a compact buffer.
+ * \param memory_scope The memory scope of the buffer.
+ */
+TVM_DLL tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype,
+                                              std::string name, int data_alignment,
+                                              int offset_factor, bool compact,
+                                              std::string memory_scope = "");
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_BUFFER_H_
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index af10d9c7252d..0446347eca2c 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -83,33 +83,6 @@ Target DefaultTargetHost(Target target) {
   }
 }
 
-tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
-                                      int data_alignment, int offset_factor, bool compact,
-                                      std::string memory_scope) {
-  DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
-  auto data = tir::Var(name, PointerType(PrimType(storage_dtype), memory_scope));
-  bool has_any = false;
-  if (!compact) {
-    for (const auto& it : shape) {
-      if (it.as<tir::AnyNode>()) {
-        has_any = true;
-        break;
-      }
-    }
-  }
-  tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault;
-
-  PrimExpr elem_offset;
-  if (offset_factor != 0) {
-    elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype());
-  } else {
-    elem_offset = PrimExpr();
-  }
-
-  return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, data_alignment,
-                     offset_factor, buffer_type);
-}
-
 void GetBinds(const Array<ObjectRef>& args, bool compact,
               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
               Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
@@ -119,8 +92,8 @@ void GetBinds(const Array<ObjectRef>& args, bool compact,
     if (const te::TensorNode* tensor_node = x.as<te::TensorNode>()) {
       te::Tensor x_ref = GetRef<te::Tensor>(tensor_node);
       if (out_binds->find(x_ref) == out_binds->end()) {
-        tir::Buffer buf =
-            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact);
+        tir::Buffer buf = tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype,
+                                                         x_ref->op->name, -1, 0, compact);
         out_binds->Set(x_ref, buf);
         out_arg_list->push_back(buf);
       } else {
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index d1682909d9bf..08fa18b61e16 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -425,8 +425,8 @@ class TECompilerImpl : public TECompilerNode {
           ICHECK(ttype->dtype == x_ref->dtype && ttype->shape.size() == x_ref->shape.size())
               << "function parameter does not correspond to prepared tensor";
           binds[x_ref] =
-              BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, false,
-                                        param->virtual_device()->memory_scope);
+              tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0,
+                                             false, param->virtual_device()->memory_scope);
         }
       }
       i++;
@@ -437,8 +437,9 @@ class TECompilerImpl : public TECompilerNode {
       ICHECK(value->cached_func->outputs.size() == 1)
           << "Expect only one output for a defined memory scope";
       te::Tensor x_ref = value->cached_func->outputs[0];
-      binds[x_ref] = BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0,
-                                               false, key->virtual_device->memory_scope);
+      binds[x_ref] =
+          tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0,
+                                         false, key->virtual_device->memory_scope);
     }
     auto func_name = value->cached_func->prim_fn_var->name_hint;
     VLOG(1) << "scheduling";
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index dffb8b499285..1ac0f1f1705e 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -585,6 +585,33 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr>
   data_ = std::move(n);
 }
 
+tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
+                                      int data_alignment, int offset_factor, bool compact,
+                                      std::string memory_scope) {
+  DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
+  auto data = tir::Var(name, PointerType(PrimType(storage_dtype), memory_scope));
+  bool has_any = false;
+  if (!compact) {
+    for (const auto& it : shape) {
+      if (it.as<tir::AnyNode>()) {
+        has_any = true;
+        break;
+      }
+    }
+  }
+  tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault;
+
+  PrimExpr elem_offset;
+  if (offset_factor != 0) {
+    elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype());
+  } else {
+    elem_offset = PrimExpr();
+  }
+
+  return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, data_alignment,
+                     offset_factor, buffer_type);
+}
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<BufferNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const BufferNode*>(node.get());
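
For context, a minimal usage sketch of the helper as it stands after patch 2,
calling tir::BufferWithOffsetAlignment the way the new te_compiler.cc call
sites do. This is not part of the patch series: the wrapper name, the
shape/dtype values, and the "global.texture" scope string are illustrative
assumptions only.

    #include <tvm/ir/expr.h>
    #include <tvm/tir/buffer.h>

    // Sketch: create a buffer whose data Var carries a non-default memory
    // scope, the same way TECompilerImpl binds scoped parameters above.
    tvm::tir::Buffer MakeScopedBuffer() {
      // Arbitrary example shape and dtype.
      tvm::Array<tvm::PrimExpr> shape = {1, 64, 56, 56};
      // data_alignment = -1 and offset_factor = 0 pick the default values,
      // matching the call sites in te_compiler.cc; compact = false permits
      // tir::Any dimensions. The scope string is stored in the PointerType of
      // the buffer's data Var, which is how it survives into lowering.
      return tvm::tir::BufferWithOffsetAlignment(shape, tvm::DataType::Float(32),
                                                 "input0", /*data_alignment=*/-1,
                                                 /*offset_factor=*/0, /*compact=*/false,
                                                 "global.texture");
    }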