Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,32 @@ class ClosureNode : public ValueNode {

RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);

/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;

/*! \brief The container type of RecClosure. */
class RecClosureNode : public ValueNode {
public:
/*! \brief The closure. */
Closure clos;
/*! \brief variable the closure bind to. */
Var bind;

RecClosureNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("clos", &clos);
v->Visit("bind", &bind);
}

TVM_DLL static RecClosure make(Closure clos, Var bind);

static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value);

/*! \brief A tuple value. */
class TupleValue;

Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class Closure(Value):
"""A closure produced by the interpreter."""


@register_relay_node
class RecClosure(Value):
"""A recursive closure produced by the interpreter."""


@register_relay_node
class ConstructorValue(Value):
def __init__(self, tag, fields, constructor):
Expand Down
35 changes: 28 additions & 7 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,27 @@ TVM_REGISTER_API("relay._make.Closure")

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
p->stream << "ClosureNode(" << node->func << ")";
p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
});


// TODO(@jroesch): this doesn't support mutual letrec
/* Value Implementation */
RecClosure RecClosureNode::make(Closure clos, Var bind) {
NodePtr<RecClosureNode> n = make_node<RecClosureNode>();
n->clos = std::move(clos);
n->bind = std::move(bind);
return RecClosure(n);
}

TVM_REGISTER_API("relay._make.RecClosure")
.set_body_typed(RecClosureNode::make);

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RecClosureNode>([](const RecClosureNode* node, tvm::IRPrinter* p) {
p->stream << "RecClosureNode(" << node->clos << ")";
});

TupleValue TupleValueNode::make(tvm::Array<Value> value) {
NodePtr<TupleValueNode> n = make_node<TupleValueNode>();
n->fields = value;
Expand Down Expand Up @@ -281,7 +299,6 @@ class Interpreter :
return TupleValueNode::make(values);
}

// TODO(@jroesch): this doesn't support mutual letrec
inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func);
Expand All @@ -298,10 +315,8 @@ class Interpreter :

// We must use mutation here to build a self referential closure.
auto closure = ClosureNode::make(captured_mod, func);
auto mut_closure =
static_cast<ClosureNode*>(const_cast<Node*>(closure.get()));
if (letrec_name.defined()) {
mut_closure->env.Set(letrec_name, closure);
return RecClosureNode::make(closure, letrec_name);
}
return std::move(closure);
}
Expand Down Expand Up @@ -559,7 +574,7 @@ class Interpreter :
}

// Invoke the closure
Value Invoke(const Closure& closure, const tvm::Array<Value>& args) {
Value Invoke(const Closure& closure, const tvm::Array<Value>& args, const Var& bind = Var()) {
// Get a reference to the function inside the closure.
if (closure->func->IsPrimitive()) {
return InvokePrimitiveOp(closure->func, args);
Expand All @@ -575,12 +590,16 @@ class Interpreter :
locals.Set(func->params[i], args[i]);
}

// Add the var to value mappings from the Closure's modironment.
// Add the var to value mappings from the Closure's environment.
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
CHECK_EQ(locals.count((*it).first), 0);
locals.Set((*it).first, (*it).second);
}

if (bind.defined()) {
locals.Set(bind, RecClosureNode::make(closure, bind));
}

return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
}

Expand All @@ -607,6 +626,8 @@ class Interpreter :
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
auto closure = GetRef<Closure>(closure_node);
return this->Invoke(closure, args);
} else if (const RecClosureNode* closure_node = fn_val.as<RecClosureNode>()) {
return this->Invoke(closure_node->clos, args, closure_node->bind);
} else {
LOG(FATAL) << "internal error: type error, expected function value in the call "
<< "position";
Expand Down