1 change: 0 additions & 1 deletion include/tvm/driver/driver_api.h
@@ -165,7 +165,6 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target&
 * \return The built module that contains code for different processors.
 */
TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const Target& target_host);
-
}  // namespace tvm

#endif  // TVM_DRIVER_DRIVER_API_H_
3 changes: 2 additions & 1 deletion include/tvm/relay/expr.h
@@ -218,7 +218,8 @@ class VarNode : public ExprNode {

  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
    equal->MarkGraphNode();
-    return equal(type_annotation, other->type_annotation) && equal(vid, other->vid);
+    return equal(type_annotation, other->type_annotation) && equal(vid, other->vid) &&
+           equal(virtual_device_, other->virtual_device_);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
17 changes: 17 additions & 0 deletions include/tvm/tir/buffer.h
@@ -295,6 +295,23 @@ class DataProducer : public ObjectRef {
  TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode);
};

+/*!
+ * \brief Creates a TIR Buffer for the provided parameters.
+ * \param shape The shape of the buffer.
+ * \param dtype The data type.
+ * \param name The buffer name.
+ * \param data_alignment The alignment requirement of the data pointer in bytes.
+ * \param offset_factor The factor of the elem_offset field; elem_offset is guaranteed to be
+ *        a multiple of offset_factor.
+ *        The user can specify data_alignment and offset_factor to be 0, in which case
+ *        a default value is picked.
+ * \param compact Whether the statement has already been bound to a compact buffer.
+ * \param memory_scope The memory scope of the buffer.
+ */
+TVM_DLL tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype,
+                                              std::string name, int data_alignment,
+                                              int offset_factor, bool compact,
+                                              std::string memory_scope = "");
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_BUFFER_H_
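
For orientation, here is a minimal usage sketch of the relocated helper. All values are illustrative: "global.texture" is just an example scope string, and the -1/0 arguments mirror the defaults that GetBinds passes in driver_api.cc below.

#include <tvm/tir/buffer.h>

// Build a buffer with default alignment/offset handling and an explicit
// memory scope. The scope ends up on the buffer's data pointer type.
tvm::tir::Buffer MakeScopedBuffer() {
  tvm::Array<tvm::PrimExpr> shape = {1, 64, 56, 56};  // illustrative NCHW shape
  return tvm::tir::BufferWithOffsetAlignment(shape, tvm::DataType::Float(32), "conv_input",
                                             /*data_alignment=*/-1, /*offset_factor=*/0,
                                             /*compact=*/false,
                                             /*memory_scope=*/"global.texture");
}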
30 changes: 2 additions & 28 deletions src/driver/driver_api.cc
@@ -83,32 +83,6 @@ Target DefaultTargetHost(Target target) {
  }
}

-tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
-                                      int data_alignment, int offset_factor, bool compact) {
-  DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
-  auto data = tir::Var(name, PointerType(PrimType(storage_dtype)));
-  bool has_any = false;
-  if (!compact) {
-    for (const auto& it : shape) {
-      if (it.as<tir::VarNode>()) {
-        has_any = true;
-        break;
-      }
-    }
-  }
-  tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault;
-
-  PrimExpr elem_offset;
-  if (offset_factor != 0) {
-    elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype());
-  } else {
-    elem_offset = PrimExpr();
-  }
-
-  return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, data_alignment,
-                     offset_factor, buffer_type);
-}
-
void GetBinds(const Array<ObjectRef>& args, bool compact,
              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
@@ -118,8 +92,8 @@ void GetBinds(const Array<ObjectRef>& args, bool compact,
    if (const te::TensorNode* tensor_node = x.as<te::TensorNode>()) {
      te::Tensor x_ref = GetRef<te::Tensor>(tensor_node);
      if (out_binds->find(x_ref) == out_binds->end()) {
-        tir::Buffer buf =
-            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact);
+        tir::Buffer buf = tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype,
+                                                         x_ref->op->name, -1, 0, compact);
        out_binds->Set(x_ref, buf);
        out_arg_list->push_back(buf);
      } else {
30 changes: 29 additions & 1 deletion src/relay/backend/te_compiler.cc
@@ -414,6 +414,33 @@ class TECompilerImpl : public TECompilerNode {
    }
    // lower the function
    std::unordered_map<te::Tensor, tir::Buffer> binds;

+    // If we have memory scopes, we need to create the tir::Buffer with that information.
+    size_t i = 0;  // index into the corresponding tensor array
+    for (Var param : key->source_func->params) {
+      if (!param->virtual_device()->memory_scope.empty()) {
+        for (const auto& ttype : FlattenTupleType(param->checked_type())) {
+          te::Tensor x_ref = value->cached_func->inputs[i];
+          // Verify that function parameters and prepared tensors are in sync.
+          ICHECK(ttype->dtype == x_ref->dtype && ttype->shape.size() == x_ref->shape.size())
+              << "function parameter does not correspond to prepared tensor";
+          binds[x_ref] =
+              tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0,
+                                             false, param->virtual_device()->memory_scope);
+        }
+      }
+      i++;
+    }
+    if (key->virtual_device != VirtualDevice::FullyUnconstrained() &&
+        !key->virtual_device->memory_scope.empty() &&
+        key->virtual_device->memory_scope != "global") {
+      ICHECK(value->cached_func->outputs.size() == 1)
+          << "Expect only one output for a defined memory scope";
+      te::Tensor x_ref = value->cached_func->outputs[0];
+      binds[x_ref] =
+          tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0,
+                                         false, key->virtual_device->memory_scope);
+    }
    auto func_name = value->cached_func->prim_fn_var->name_hint;
    VLOG(1) << "scheduling";
    IRModule scheduled_module =
@@ -895,7 +922,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
    } else {
      // Cases 1 and 2: lower the primitive function for the desired target, possibly using external
      // codegen.
-      CCacheKey key(Downcast<Function>(primitive_func), target);
+      CCacheKey key(Downcast<Function>(primitive_func), target,
+                    GetVirtualDevice(GetRef<Call>(call_node)));
      CachedFunc cfunc = compiler_->Lower(key, module_name_);
      ICHECK(cfunc.defined());
      return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, std::move(new_args),
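
The practical effect of threading the call's VirtualDevice into the key: calls to the same primitive function on the same target, but with different memory scopes, no longer share a cache entry. A sketch under stated assumptions: prim_func and target are hypothetical handles, and VirtualDevice::ForMemoryScope is the factory from include/tvm/target/virtual_device.h.

// Two lookups for one function/target pair now diverge once the virtual
// devices disagree, so each memory-scope variant is lowered separately.
CCacheKey default_key(prim_func, target);  // virtual_device = FullyUnconstrained()
CCacheKey texture_key(prim_func, target,
                      VirtualDevice::ForMemoryScope("global.texture"));
// compiler_->Lower(default_key, ...) and compiler_->Lower(texture_key, ...)
// populate two distinct CachedFunc entries.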
3 changes: 2 additions & 1 deletion src/relay/backend/te_compiler_cache.cc
@@ -66,10 +66,11 @@ LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation im
  data_ = std::move(n);
}

-CCacheKey::CCacheKey(Function source_func, Target target) {
+CCacheKey::CCacheKey(Function source_func, Target target, VirtualDevice vd) {
  auto n = make_object<CCacheKeyNode>();
  n->source_func = std::move(source_func);
  n->target = std::move(target);
+  n->virtual_device = std::move(vd);
  data_ = std::move(n);
}

7 changes: 6 additions & 1 deletion src/relay/backend/te_compiler_cache.h
@@ -82,10 +82,13 @@ class CCacheKeyNode : public Object {
  Function source_func;
  /*! \brief The hardware target.*/
  Target target;
+  /*! \brief The virtual device constraints.*/
+  VirtualDevice virtual_device;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("source_func", &source_func);
    v->Visit("target", &target);
+    v->Visit("virtual_device", &virtual_device);
  }
  /*! \return The hash value of CCacheKey. */
  inline size_t Hash() const;
@@ -117,7 +120,8 @@ class CCacheKey : public ObjectRef {
   * \param source_func The source function.
   * \param target The target device.
   */
-  TVM_DLL CCacheKey(Function source_func, Target target);
+  TVM_DLL CCacheKey(Function source_func, Target target,
+                    VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained());

  const CCacheKeyNode* operator->() const { return static_cast<const CCacheKeyNode*>(get()); }
  // comparator
@@ -244,6 +248,7 @@ inline size_t CCacheKeyNode::Hash() const {
inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const {
  if (Hash() != other->Hash()) return false;
  return this->target->str() == other->target->str() &&
+         this->virtual_device == other->virtual_device &&
         tvm::StructuralEqual()(this->source_func, other->source_func);
}

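
One subtlety: Hash() is untouched by this diff, so two keys that differ only in their virtual device hash to the same value and are separated by the Equal() check above. That remains correct; it merely costs an extra comparison on collision. A sketch, assuming CCacheKey's operator== (defined just after the comparator comment in this header) and hypothetical func/target handles:

CCacheKey a(func, target);  // defaults to VirtualDevice::FullyUnconstrained()
CCacheKey b(func, target, VirtualDevice::ForMemoryScope("global.texture"));
ICHECK(a->Hash() == b->Hash());  // same bucket: the hash ignores virtual_device
ICHECK(!(a == b));               // but Equal() now tells the keys apart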
27 changes: 27 additions & 0 deletions src/tir/ir/buffer.cc
@@ -585,6 +585,33 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr>
  data_ = std::move(n);
}

+tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
+                                      int data_alignment, int offset_factor, bool compact,
+                                      std::string memory_scope) {
+  DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
+  auto data = tir::Var(name, PointerType(PrimType(storage_dtype), memory_scope));
+  bool has_any = false;
+  if (!compact) {
+    for (const auto& it : shape) {
+      if (it.as<tir::VarNode>()) {
+        has_any = true;
+        break;
+      }
+    }
+  }
+  tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault;
+
+  PrimExpr elem_offset;
+  if (offset_factor != 0) {
+    elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype());
+  } else {
+    elem_offset = PrimExpr();
+  }
+
+  return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, data_alignment,
+                     offset_factor, buffer_type);
+}
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<BufferNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const BufferNode*>(node.get());
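
The only functional change relative to the deleted driver_api.cc copy is that memory_scope is threaded into the PointerType of the buffer's data Var. A quick check sketch, assuming PointerTypeNode::storage_scope (the field that carries the scope string) and illustrative shape/dtype/scope values:

#include <tvm/ir/type.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/buffer.h>

// The scope string lands on the underlying data pointer, which is what
// later lowering and codegen consult when emitting scoped accesses.
void CheckScope() {
  auto buf = tvm::tir::BufferWithOffsetAlignment({16, 16}, tvm::DataType::Float(16), "acc",
                                                 /*data_alignment=*/-1, /*offset_factor=*/0,
                                                 /*compact=*/false, "global.texture");
  const auto* ptr = buf->data->type_annotation.as<tvm::PointerTypeNode>();
  ICHECK(ptr != nullptr && ptr->storage_scope == "global.texture");
}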