Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class GlobalVarNode : public RelayExprNode {
*/
class GlobalVar : public RelayExpr {
public:
TVM_DLL explicit GlobalVar(String name_hint);
TVM_DLL explicit GlobalVar(String name_hint, Type type = {});

TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/on_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
SEScope se_scope = SEScope::FullyUnconstrained();

/*!
* \brief If fales (the default), the result of the "on_device" call is not constrained to be
* \brief If false (the default), the result of the "on_device" call is not constrained to be
* \p se_scope.
*/
bool constrain_result = false;
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,15 @@ class Var : public Expr {
*/
TVM_DLL Var(Id vid, Type type_annotation, Span span = Span());

/*!
* \brief Return a globally fresh name. Helps with debugging to follow the same
* variable between passes and sub-expressions.
*
* TODO(mbs): Replace with name creation w.r.t. scopes once available as part of
* name gen overhaul.
*/
static Var GenSym(Type type_annotation = {}, Span span = {});

TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
};
Expand Down
8 changes: 5 additions & 3 deletions include/tvm/target/se_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,21 @@ using MemoryScope = String;
*
*/
class SEScopeNode : public AttrsNode<SEScopeNode> {
public:
private:
/*!
* \brief The \p DLDeviceType (represtented as an int) of the virtual device. If \p target is
* \brief The \p DLDeviceType (represented as an int) of the virtual device. If \p target is
* known then this will be equal to \p target->kind->device_type. If \p target is null then the
* target is to be determined later.
*
* This is needed to support the legacy "on_device" and "device_copy" calls which only allow
* a \p DLDeviceTypes (as an integer) to be given.
*
* kInvalidDeviceType denotes unconstrained.
* kInvalidDeviceType denotes unconstrained. An int since the DLDeviceType enum representation
* is not fixed. Private to discourage further int vs DLDeviceType confusion.
*/
int /* actually DLDeviceType */ device_type_int;

public:
DLDeviceType device_type() const { return static_cast<DLDeviceType>(device_type_int); }

/*!
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class GlobalVar(RelayExpr):
The name of the variable.
"""

def __init__(self, name_hint):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint)
def __init__(self, name_hint, type_annot=None):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot)

def __call__(self, *args):
"""Call the global variable.
Expand Down
7 changes: 5 additions & 2 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});

GlobalVar::GlobalVar(String name_hint) {
GlobalVar::GlobalVar(String name_hint, Type type) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
n->checked_type_ = std::move(type);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(GlobalVarNode);

TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) { return GlobalVar(name); });
TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) {
return GlobalVar(name, type);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
Expand Down
4 changes: 3 additions & 1 deletion src/relay/backend/contrib/cmsisnn/extract_constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class ExtractConstantsMutator : public MixedModeMutator {
private:
String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const FunctionNode* function) final {
Function func = GetRef<Function>(function);
function_to_constants_.Set(func, Array<Constant>{});
Expand All @@ -56,7 +58,7 @@ class ExtractConstantsMutator : public MixedModeMutator {
func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_),
func->attrs);
}
return func;
return std::move(func);
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/cmsisnn/generate_constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class GenerateConstantsMutator : public MixedModeMutator {
if (clip_call) {
ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {});
}
return ret_call;
return std::move(ret_call);
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expand Down
4 changes: 3 additions & 1 deletion src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
int clip_max;
};

using codegen::CodeGenCHost::VisitStmt_;

/*! * \brief Emits CMSIS-NN APIs for every call_extern */
void VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final {
if (!op->op.same_as(builtin::call_extern())) {
CodeGenCHost::VisitExpr_(op, os);
return;
Expand Down
3 changes: 1 addition & 2 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -694,8 +694,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode {
*rv = this->output_.external_mods;
});
} else if (name == "get_devices") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array<String>(); });
return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = Array<String>(); });
} else if (name == "get_metadata") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.metadata; });
Expand Down
4 changes: 4 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,10 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
// let-bound functions.
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));

// Now that we have PrimFuncs, flow and solve SEScope constraints again to account for
// any memory scopes which lowering has settled on.
pass_seqs.push_back(transform::PlanDevices(config_));

// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/adt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ Clause WithFields(Clause clause, Optional<Pattern> opt_lhs, Optional<Expr> opt_r
cow_clause_node->lhs = lhs;
cow_clause_node->rhs = rhs;
}
return std::move(clause);
return clause;
}

TVM_REGISTER_NODE_TYPE(ClauseNode);
Expand Down Expand Up @@ -168,7 +168,7 @@ Match WithFields(Match match, Optional<Expr> opt_data, Optional<Array<Clause>> o
cow_match_node->complete = complete;
cow_match_node->span = span;
}
return std::move(match);
return match;
}

TVM_REGISTER_NODE_TYPE(MatchNode);
Expand Down
49 changes: 28 additions & 21 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields,
cow_tuple_node->virtual_device_ = virtual_device;
cow_tuple_node->span = span;
}
return std::move(tuple);
return tuple;
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand All @@ -124,6 +124,13 @@ Var::Var(Id vid, Type type_annotation, Span span) {
data_ = std::move(n);
}

/* static */ Var Var::GenSym(Type type_annotation, Span span) {
static size_t next_id = std::atomic<size_t>(0);
std::ostringstream os;
os << "x_" << next_id++;
return Var(os.str(), std::move(type_annotation), std::move(span));
}

Var WithFields(Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation,
Optional<SEScope> opt_virtual_device, Optional<Span> opt_span) {
Id vid = opt_vid.value_or(var->vid);
Expand All @@ -141,7 +148,7 @@ Var WithFields(Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation
cow_var_node->virtual_device_ = virtual_device;
cow_var_node->span = span;
}
return std::move(var);
return var;
}

TVM_REGISTER_NODE_TYPE(VarNode);
Expand Down Expand Up @@ -219,7 +226,7 @@ Call WithFields(Call call, Optional<Expr> opt_op, Optional<Array<Expr>> opt_args
cow_call_node->virtual_device_ = virtual_device;
cow_call_node->span = span;
}
return std::move(call);
return call;
}

TVM_REGISTER_NODE_TYPE(CallNode);
Expand Down Expand Up @@ -264,7 +271,7 @@ Let WithFields(Let let, Optional<Var> opt_var, Optional<Expr> opt_value, Optiona
cow_let_node->virtual_device_ = virtual_device;
cow_let_node->span = span;
}
return std::move(let);
return let;
}

TVM_REGISTER_NODE_TYPE(LetNode);
Expand Down Expand Up @@ -308,7 +315,7 @@ If WithFields(If if_expr, Optional<Expr> opt_cond, Optional<Expr> opt_true_branc
cow_if_node->virtual_device_ = virtual_device;
cow_if_node->span = span;
}
return std::move(if_expr);
return if_expr;
}

TVM_REGISTER_NODE_TYPE(IfNode);
Expand Down Expand Up @@ -350,7 +357,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,
cow_tuple_get_item_node->span = span;
cow_tuple_get_item_node->virtual_device_ = virtual_device;
}
return std::move(tuple_get_item);
return tuple_get_item;
}

TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
Expand Down Expand Up @@ -385,7 +392,7 @@ RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value,
cow_ref_create_node->virtual_device_ = virtual_device;
cow_ref_create_node->span = span;
}
return std::move(ref_create);
return ref_create;
}

TVM_REGISTER_NODE_TYPE(RefCreateNode);
Expand Down Expand Up @@ -420,7 +427,7 @@ RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref, Optional<SEScope> o
cow_ref_read_node->virtual_device_ = virtual_device;
cow_ref_read_node->span = span;
}
return std::move(ref_read);
return ref_read;
}

TVM_REGISTER_NODE_TYPE(RefReadNode);
Expand Down Expand Up @@ -457,7 +464,7 @@ RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref, Optional<Expr> o
cow_ref_write_node->virtual_device_ = virtual_device;
cow_ref_write_node->span = span;
}
return std::move(ref_write);
return ref_write;
}

TVM_REGISTER_NODE_TYPE(RefWriteNode);
Expand Down Expand Up @@ -510,29 +517,29 @@ inline void Dismantle(const Expr& expr) {
stack.top().second = true;

// special handling
if (const CallNode* op = node.as<CallNode>()) {
if (const auto* call_node = node.as<CallNode>()) {
// do not process args if used elsewhere
if (op->args.use_count() < 2) {
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
if (call_node->args.use_count() < 2) {
for (auto it = call_node->args.rbegin(); it != call_node->args.rend(); ++it) {
fpush_to_stack(*it);
}
}
} else if (const TupleNode* op = node.as<TupleNode>()) {
} else if (const auto* tuple_node = node.as<TupleNode>()) {
// do not process fields if used elsewhere
if (op->fields.use_count() < 2) {
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
if (tuple_node->fields.use_count() < 2) {
for (auto it = tuple_node->fields.rbegin(); it != tuple_node->fields.rend(); ++it) {
fpush_to_stack(*it);
}
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
} else if (const auto* tuple_get_item_node = node.as<TupleGetItemNode>()) {
// do not process tuple if used elsewhere
if (op->tuple.use_count() < 2) {
fpush_to_stack(op->tuple);
if (tuple_get_item_node->tuple.use_count() < 2) {
fpush_to_stack(tuple_get_item_node->tuple);
}
} else if (const LetNode* op = node.as<LetNode>()) {
} else if (const auto* let_node = node.as<LetNode>()) {
// do not process let if used elsewhere
if (op->body.use_count() < 2) {
fpush_to_stack(op->body);
if (let_node->body.use_count() < 2) {
fpush_to_stack(let_node->body);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ Function WithFields(Function function, Optional<Array<Var>> opt_params, Optional
cow_function_node->virtual_device_ = virtual_device;
cow_function_node->span = span;
}
return std::move(function);
return function;
}

FuncType FunctionNode::func_type_annotation() const {
Expand Down
4 changes: 3 additions & 1 deletion src/relay/op/memory/on_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool cons
ICHECK(inner == outer)
<< "Cannot constrain intermediate result of nested on_device calls to different SEScopes";
}
// We can now ignore the intermediate constraints, if any.
// We can now ignore the middle constraint.
// If the outer on_device has any constraint then use se_scope given for it.
// Otherwise we can use the existing inner se_scope.
return OnDevice(props.body, (constrain_inner || constrain_outer) ? outer : inner,
constrain_outer, constrain_inner);
} else {
Expand Down
9 changes: 6 additions & 3 deletions src/relay/op/memory/on_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,18 @@ struct OnDeviceProps {
};

/*!
* \brief As for OnDevice, but taking all fields other than \p body from \p props.
* \brief Wraps \p body in an "on_device" CallNode, taking all fields other than \p body from \p
* props.
*/
inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) {
return OnDevice(std::move(body), props.se_scope, props.constrain_result, props.constrain_body);
}

/*!
* \brief As for OnDevice, but don't constrain the body or result to any particular virtual device.
* This allows a "device_copy" when required.
* \brief Wraps \p body in an "on_device" CallNode, but don't constrain the body or result to
* any particular virtual device. This allows a "device_copy" to be inserted by PlanDevices
* where required, while at the same time not introducing unnecessary freedom in the device
* choices.
*/
inline Call OnDeviceCopyOk(Expr body) {
return OnDevice(std::move(body), SEScope::FullyUnconstrained(),
Expand Down
12 changes: 9 additions & 3 deletions src/relay/transforms/device_domains.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,15 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) {
DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get());
CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get());

// TODO(mbs): Support call_lowered to PrimFuncs.
ICHECK(!call_lowered_props.lowered_func.defined());
if (on_device_props.body.defined()) {
if (call_lowered_props.lowered_func.defined()) {
// Presumably we've already seen the call to the "primitive" Function from which this lowered
// function was derived in an earlier PlanDevices pass. Thus we've already established that
// all the argument and result devices domains must be equal, ignoring memory scopes.
// So at this point we'll let all the arguments and result be free so that memory scopes can
// differ.
// TODO(mbs): As per header comments, need to revisit when can setup sub-SEScope constraints.
return DomainFor(call_lowered_props.lowered_func);
} else if (on_device_props.body.defined()) {
// By default:
// on_device(expr, se_scope=<t>)
// on_device : fn(<t>):?x?
Expand Down
Loading