From 728eeecfe33afe6ff85e171b59c26b7bdab3a462 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Fri, 18 Oct 2019 18:00:42 +0000 Subject: [PATCH 1/2] save lint --- include/tvm/relay/interpreter.h | 26 ++++++++++++++++++ python/tvm/relay/backend/interpreter.py | 5 ++++ src/relay/backend/interpreter.cc | 35 ++++++++++++++++++++----- 3 files changed, 59 insertions(+), 7 deletions(-) diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index a0422fa7f446..7f6aefa39f5c 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -119,6 +119,32 @@ class ClosureNode : public ValueNode { RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); +/*! \brief A Relay Recursive Closure. A closure that have a name. */ +class RecClosure; + +/*! \brief The container type of RecClosure. */ +class RecClosureNode : public ValueNode { + public: + /*! \brief The closure. */ + Closure clos; + /*! \brief variable the closure bind to. */ + Var bind; + + RecClosureNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("clos", &clos); + v->Visit("bind", &bind); + } + + TVM_DLL static RecClosure make(Closure clos, Var bind); + + static constexpr const char* _type_key = "relay.RecClosure"; + TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value); + /*! \brief A tuple value. */ class TupleValue; diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index ae60b7a89b2f..1d53f6a92b07 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -72,6 +72,11 @@ class Closure(Value): """A closure produced by the interpreter.""" +@register_relay_node +class RecClosure(Value): + """A recursive closure produced by the interpreter.""" + + @register_relay_node class ConstructorValue(Value): def __init__(self, tag, fields, constructor): diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 86a4ebb4ebd2..2703b1c8634a 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -56,9 +56,27 @@ TVM_REGISTER_API("relay._make.Closure") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ClosureNode* node, tvm::IRPrinter* p) { - p->stream << "ClosureNode(" << node->func << ")"; + p->stream << "ClosureNode(" << node->func << ", " << node->env << ")"; }); + +// TODO(@jroesch): this doesn't support mutual letrec +/* Value Implementation */ +RecClosure RecClosureNode::make(Closure clos, Var bind) { + NodePtr n = make_node(); + n->clos = std::move(clos); + n->bind = std::move(bind); + return RecClosure(n); +} + +TVM_REGISTER_API("relay._make.RecClosure") +.set_body_typed(RecClosureNode::make); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RecClosureNode* node, tvm::IRPrinter* p) { + p->stream << "RecClosureNode(" << node->clos << ")"; + }); + TupleValue TupleValueNode::make(tvm::Array value) { NodePtr n = make_node(); n->fields = value; @@ -281,7 +299,6 @@ class Interpreter : return TupleValueNode::make(values); } - // TODO(@jroesch): this doesn't support mutual letrec inline Value MakeClosure(const Function& func, Var letrec_name = Var()) { tvm::Map captured_mod; Array free_vars = FreeVars(func); @@ -298,10 +315,8 @@ class Interpreter : // We must use mutation here to build a self referential closure. auto closure = ClosureNode::make(captured_mod, func); - auto mut_closure = - static_cast(const_cast(closure.get())); if (letrec_name.defined()) { - mut_closure->env.Set(letrec_name, closure); + return RecClosureNode::make(closure, letrec_name); } return std::move(closure); } @@ -559,7 +574,7 @@ class Interpreter : } // Invoke the closure - Value Invoke(const Closure& closure, const tvm::Array& args) { + Value Invoke(const Closure& closure, const tvm::Array& args, const Var& bind = Var()) { // Get a reference to the function inside the closure. if (closure->func->IsPrimitive()) { return InvokePrimitiveOp(closure->func, args); @@ -575,12 +590,16 @@ class Interpreter : locals.Set(func->params[i], args[i]); } - // Add the var to value mappings from the Closure's modironment. + // Add the var to value mappings from the Closure's environment. for (auto it = closure->env.begin(); it != closure->env.end(); ++it) { CHECK_EQ(locals.count((*it).first), 0); locals.Set((*it).first, (*it).second); } + if (bind.defined()) { + locals.Set(bind, RecClosureNode::make(closure, bind)); + } + return WithFrame(Frame(locals), [&]() { return Eval(func->body); }); } @@ -607,6 +626,8 @@ class Interpreter : if (const ClosureNode* closure_node = fn_val.as()) { auto closure = GetRef(closure_node); return this->Invoke(closure, args); + } else if (const RecClosureNode* closure_node = fn_val.as()) { + return this->Invoke(closure_node->clos, args, closure_node->bind); } else { LOG(FATAL) << "internal error: type error, expected function value in the call " << "position"; From c4f5187c0a54a992c911229d1a621d5d2438475c Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 23 Oct 2019 17:10:26 +0000 Subject: [PATCH 2/2] address reviewer comment --- include/tvm/relay/interpreter.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 7f6aefa39f5c..f0b1e7ce8a26 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -119,7 +119,7 @@ class ClosureNode : public ValueNode { RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); -/*! \brief A Relay Recursive Closure. A closure that have a name. */ +/*! \brief A Relay Recursive Closure. A closure that has a name. */ class RecClosure; /*! \brief The container type of RecClosure. */