diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h
index 45a938247cc8..48800b193cb4 100644
--- a/include/tvm/driver/driver_api.h
+++ b/include/tvm/driver/driver_api.h
@@ -165,7 +165,6 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target&
  * \return The built module that contains code for different processors.
  */
 TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const Target& target_host);
-
 }  // namespace tvm

 #endif  // TVM_DRIVER_DRIVER_API_H_
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 6b014c8478d8..bd094a7f6905 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -218,7 +218,8 @@ class VarNode : public ExprNode {

   bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
     equal->MarkGraphNode();
-    return equal(type_annotation, other->type_annotation) && equal(vid, other->vid);
+    return equal(type_annotation, other->type_annotation) && equal(vid, other->vid) &&
+           equal(virtual_device_, other->virtual_device_);
   }

   void SHashReduce(SHashReducer hash_reduce) const {
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index ca7faf1cdefb..d7a2aec0b972 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -295,6 +295,23 @@ class DataProducer : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode);
 };

+/*!
+ * \brief Creates a TIR Buffer for the provided parameters.
+ * \param shape The shape of the buffer.
+ * \param dtype The data type.
+ * \param name The buffer name.
+ * \param data_alignment The alignment requirement of the data pointer in bytes.
+ * \param offset_factor The factor of the elem_offset field; elem_offset is guaranteed to be
+ *        a multiple of offset_factor. data_alignment and offset_factor may be set to 0, in
+ *        which case a default value is picked.
+ * \param compact Whether the statement has already been bound to a compact buffer.
+ * \param memory_scope The memory scope of the buffer.
+ */
+TVM_DLL tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype,
+                                              std::string name, int data_alignment,
+                                              int offset_factor, bool compact,
+                                              std::string memory_scope = "");
+
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_BUFFER_H_
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 7f015e7ca2b9..0446347eca2c 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -83,32 +83,6 @@ Target DefaultTargetHost(Target target) {
   }
 }

-tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
-                                      int data_alignment, int offset_factor, bool compact) {
-  DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
-  auto data = tir::Var(name, PointerType(PrimType(storage_dtype)));
-  bool has_any = false;
-  if (!compact) {
-    for (const auto& it : shape) {
-      if (it.as<tir::AnyNode>()) {
-        has_any = true;
-        break;
-      }
-    }
-  }
-  tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault;
-
-  PrimExpr elem_offset;
-  if (offset_factor != 0) {
-    elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype());
-  } else {
-    elem_offset = PrimExpr();
-  }
-
-  return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, data_alignment,
-                     offset_factor, buffer_type);
-}
-
 void GetBinds(const Array<ObjectRef>& args, bool compact,
               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
               Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
@@ -118,8 +92,8 @@ void GetBinds(const Array<ObjectRef>& args, bool compact,
     if (const te::TensorNode* tensor_node = x.as<te::TensorNode>()) {
       te::Tensor x_ref = GetRef<te::Tensor>(tensor_node);
       if (out_binds->find(x_ref) == out_binds->end()) {
-        tir::Buffer buf =
-            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact);
+        tir::Buffer buf = tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype,
+                                                         x_ref->op->name, -1, 0, compact);
         out_binds->Set(x_ref, buf);
         out_arg_list->push_back(buf);
       } else {
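Taken together, the buffer.h and driver_api.cc hunks move `BufferWithOffsetAlignment` into the public `tir` namespace and add a `memory_scope` parameter. A minimal sketch of the new entry point, assuming a TVM build with this patch applied; the shape, tensor name, and the `"global.texture"` scope are illustrative values, not anything the patch mandates:

```cpp
#include <tvm/ir/type.h>
#include <tvm/tir/buffer.h>

using namespace tvm;

void SketchScopedBuffer() {
  // -1 data_alignment / 0 offset_factor request the defaults, mirroring the
  // GetBinds call site above; only the trailing memory_scope is new.
  tir::Buffer buf = tir::BufferWithOffsetAlignment(
      {1, 64, 56, 56}, DataType::Float(32), "conv_input",
      /*data_alignment=*/-1, /*offset_factor=*/0, /*compact=*/false,
      /*memory_scope=*/"global.texture");
  // The scope travels on the pointer type of the buffer's backing var.
  const auto* ptr = buf->data->type_annotation.as<PointerTypeNode>();
  ICHECK(ptr != nullptr && ptr->storage_scope == "global.texture");
}
```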
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index e9491b0a8901..08fa18b61e16 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -414,6 +414,33 @@ class TECompilerImpl : public TECompilerNode {
     }
     // lower the function
     std::unordered_map<te::Tensor, tir::Buffer> binds;
+
+    // If any parameters carry a memory scope, the tir::Buffers bound to them
+    // must be created with that scope.
+    size_t i = 0;  // index into the corresponding tensor array
+    for (Var param : key->source_func->params) {
+      if (!param->virtual_device()->memory_scope.empty()) {
+        for (const auto& ttype : FlattenTupleType(param->checked_type())) {
+          te::Tensor x_ref = value->cached_func->inputs[i];
+          // Verify that function parameters and prepared tensors are in sync.
+          ICHECK(ttype->dtype == x_ref->dtype && ttype->shape.size() == x_ref->shape.size())
+              << "function parameter does not correspond to prepared tensor";
+          binds[x_ref] =
+              tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0,
+                                             false, param->virtual_device()->memory_scope);
+        }
+      }
+      i++;
+    }
+    if (key->virtual_device != VirtualDevice::FullyUnconstrained() &&
+        !key->virtual_device->memory_scope.empty() &&
+        key->virtual_device->memory_scope != "global") {
+      ICHECK(value->cached_func->outputs.size() == 1)
+          << "expect exactly one output when a memory scope is defined";
+      te::Tensor x_ref = value->cached_func->outputs[0];
+      binds[x_ref] =
+          tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0,
+                                         false, key->virtual_device->memory_scope);
+    }
     auto func_name = value->cached_func->prim_fn_var->name_hint;
     VLOG(1) << "scheduling";
     IRModule scheduled_module =
@@ -895,7 +922,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
     } else {
       // Cases 1 and 2: lower the primitive function for the desired target, possibly using external
       // codegen.
-      CCacheKey key(Downcast<Function>(primitive_func), target);
+      CCacheKey key(Downcast<Function>(primitive_func), target,
+                    GetVirtualDevice(GetRef<Call>(call_node)));
       CachedFunc cfunc = compiler_->Lower(key, module_name_);
       ICHECK(cfunc.defined());
       return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, std::move(new_args),
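Threading the call's `VirtualDevice` into `CCacheKey` means the same primitive function lowered under different memory scopes yields distinct cache entries. A hypothetical sketch, assuming `func` and `target` are in scope and that `VirtualDevice` can be constructed from a device type, device id, target, and memory scope (this diff does not show its constructor):

```cpp
#include <tvm/target/virtual_device.h>

#include "te_compiler_cache.h"  // internal backend header; include path illustrative

using tvm::Target;
using tvm::VirtualDevice;
using tvm::relay::Function;
using tvm::relay::tec::CCacheKey;

void SketchScopeAwareKeys(Function func, Target target) {
  // Behaves exactly as before this patch: the third argument defaults to
  // VirtualDevice::FullyUnconstrained().
  CCacheKey plain(func, target);
  // Constructor arguments here are assumptions; the diff only shows that a
  // VirtualDevice carrying a memory_scope can be stored in the key.
  VirtualDevice texture_vd(kDLOpenCL, /*virtual_device_id=*/0, target,
                           /*memory_scope=*/"global.texture");
  CCacheKey scoped(func, target, texture_vd);
  // Distinct keys mean the TE compiler lowers the function once per scope
  // instead of reusing a single PrimFunc.
  ICHECK(!plain->Equal(scoped.operator->()));
}
```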
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 8715900c0c4a..17e3d573053f 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -66,10 +66,11 @@ LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation im
   data_ = std::move(n);
 }

-CCacheKey::CCacheKey(Function source_func, Target target) {
+CCacheKey::CCacheKey(Function source_func, Target target, VirtualDevice vd) {
   auto n = make_object<CCacheKeyNode>();
   n->source_func = std::move(source_func);
   n->target = std::move(target);
+  n->virtual_device = std::move(vd);
   data_ = std::move(n);
 }

diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h
index 55f221ac8ba0..ac2619826019 100644
--- a/src/relay/backend/te_compiler_cache.h
+++ b/src/relay/backend/te_compiler_cache.h
@@ -82,10 +82,13 @@ class CCacheKeyNode : public Object {
   Function source_func;
   /*! \brief The hardware target.*/
   Target target;
+  /*! \brief The virtual device constraint.*/
+  VirtualDevice virtual_device;

   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("source_func", &source_func);
     v->Visit("target", &target);
+    v->Visit("virtual_device", &virtual_device);
   }
   /*! \return The hash value of CCacheKey. */
   inline size_t Hash() const;
@@ -117,7 +120,8 @@ class CCacheKey : public ObjectRef {
    * \param source_func The source function.
    * \param target The target device.
    */
-  TVM_DLL CCacheKey(Function source_func, Target target);
+  TVM_DLL CCacheKey(Function source_func, Target target,
+                    VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained());

   const CCacheKeyNode* operator->() const { return static_cast<const CCacheKeyNode*>(get()); }
   // comparator
@@ -244,6 +248,7 @@ inline size_t CCacheKeyNode::Hash() const {
 inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const {
   if (Hash() != other->Hash()) return false;
   return this->target->str() == other->target->str() &&
+         this->virtual_device == other->virtual_device &&
          tvm::StructuralEqual()(this->source_func, other->source_func);
 }
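One design note: `virtual_device` participates in `Equal` but not in `Hash`. That preserves the invariant that equal keys hash equal, while keys differing only in scope share a hash bucket and are separated by the equality check. A toy stand-in (plain C++, not TVM code) that demonstrates the invariant:

```cpp
#include <cassert>
#include <functional>
#include <string>

// Toy stand-in for CCacheKeyNode: the scope field takes part in equality but
// not in the hash, mirroring the header above.
struct ToyKey {
  std::string func;    // stands in for the structural hash of source_func
  std::string target;  // stands in for target->str()
  std::string scope;   // stands in for virtual_device

  size_t Hash() const {  // scope deliberately omitted
    return std::hash<std::string>()(func) ^ (std::hash<std::string>()(target) << 1);
  }
  bool Equal(const ToyKey& other) const {
    if (Hash() != other.Hash()) return false;
    return func == other.func && target == other.target && scope == other.scope;
  }
};

int main() {
  ToyKey a{"fused_add", "opencl", ""};
  ToyKey b{"fused_add", "opencl", "global.texture"};
  assert(a.Hash() == b.Hash());  // same bucket...
  assert(!a.Equal(b));           // ...but distinct cache entries
  return 0;
}
```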
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index dffb8b499285..1ac0f1f1705e 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -585,6 +585,33 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr>
   data_ = std::move(n);
 }

+tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
+                                      int data_alignment, int offset_factor, bool compact,
+                                      std::string memory_scope) {
+  DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
+  auto data = tir::Var(name, PointerType(PrimType(storage_dtype), memory_scope));
+  bool has_any = false;
+  if (!compact) {
+    for (const auto& it : shape) {
+      if (it.as<tir::AnyNode>()) {
+        has_any = true;
+        break;
+      }
+    }
+  }
+  tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault;
+
+  PrimExpr elem_offset;
+  if (offset_factor != 0) {
+    elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype());
+  } else {
+    elem_offset = PrimExpr();
+  }
+
+  return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, data_alignment,
+                     offset_factor, buffer_type);
+}
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<BufferNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const BufferNode*>(node.get());
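Finally, the relocated implementation keeps the old default handling: Bool buffers get int8 storage, and a non-zero offset_factor mints a symbolic elem_offset var. A small sketch under the same applied-patch assumption; the names and shapes are invented:

```cpp
#include <tvm/tir/buffer.h>

using namespace tvm;

void SketchHelperDefaults() {
  // A non-zero offset_factor makes the helper mint a fresh symbolic
  // "w_elem_offset" var instead of leaving elem_offset undefined.
  tir::Buffer w = tir::BufferWithOffsetAlignment(
      {128, 128}, DataType::Float(32), "w",
      /*data_alignment=*/-1, /*offset_factor=*/16, /*compact=*/true);
  ICHECK(w->elem_offset.defined());

  // Bool data is backed by int8 storage, while the buffer dtype stays Bool.
  tir::Buffer flags = tir::BufferWithOffsetAlignment(
      {32}, DataType::Bool(), "flags", /*data_alignment=*/-1,
      /*offset_factor=*/0, /*compact=*/true);
  ICHECK(flags->dtype == DataType::Bool());
}
```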