Skip to content
29 changes: 29 additions & 0 deletions include/tvm/relay/attrs/debug.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/debug.h
* \brief Auxiliary attributes for debug operators.
*/
#ifndef TVM_RELAY_ATTRS_DEBUG_H_
#define TVM_RELAY_ATTRS_DEBUG_H_

#include <tvm/attrs.h>
#include <string>

namespace tvm {
namespace relay {

/*!
* \brief Options for the debug operators.
*/
struct DebugAttrs : public tvm::AttrsNode<DebugAttrs> {
EnvFunc debug_func;

TVM_DECLARE_ATTRS(DebugAttrs, "relay.attrs.DebugAttrs") {
TVM_ATTR_FIELD(debug_func)
.describe("The function to use when debugging.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_DEBUG_H_
5 changes: 5 additions & 0 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ using TOpPattern = int;
*/
using TOpIsStateful = bool;

/*!
* \brief Mark the operator as non-computational.
*/
using TNonComputational = bool;

/*!
* \brief Computation description interface.
*
Expand Down
7 changes: 1 addition & 6 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import ir_pass
from .build_module import build, build_config, create_executor
from . import parser
from . import debug

# Root operators
from .op import Op
Expand Down Expand Up @@ -58,11 +59,5 @@
const = expr.const
bind = expr.bind

# pylint: disable=unused-argument
@register_func("relay.debug")
def _debug(*args):
import pdb
pdb.set_trace()

# Parser
fromtext = parser.fromtext
25 changes: 25 additions & 0 deletions python/tvm/relay/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import
from .base import NodeBase, register_relay_node
from ..api import register_func

@register_relay_node
class InterpreterState(NodeBase):
pass

# pylint: disable=unused-argument
def _debugger_init(expr, stack):
import pdb
pdb.set_trace()

# pylint: disable=unused-argument
@register_func("relay.debug")
def _debug(*args):
_, _, _, ist = args
print("Relay Debugger")
print(" You can manipulate the expression under evaluation with the name `expr`.")
print(" You can manipulate the call stack with the name `stack`.")
print("--------------")
print("--------------")
_debugger_init(ist.current_expr, ist.stack)
2 changes: 2 additions & 0 deletions python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# operator defs
from .op import get, register, register_schedule, register_compute, register_alter_op_layout, \
Op
from .op import debug

# Operators
from .reduce import *
Expand All @@ -13,6 +14,7 @@
from . import vision
from . import op_attrs


# operator registry
from . import _tensor
from . import _transform
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..expr import Expr
from ...api import register_func
from ...build_module import lower, build
from . import _make

@register_relay_node
class Op(Expr):
Expand Down Expand Up @@ -183,3 +184,18 @@ def schedule_injective(attrs, outputs, target):
"""Generic schedule for binary broadcast."""
with target:
return topi.generic.schedule_injective(outputs)

__DEBUG_COUNTER__ = 0

def debug(expr, debug_func=None):
"""The main entry point to the debugger."""
global __DEBUG_COUNTER__

if debug_func:
name = "debugger_func{}".format(__DEBUG_COUNTER__)
register_func(name, debug_func)
__DEBUG_COUNTER__ += 1
else:
name = ''

return _make.debug(expr, name)
63 changes: 62 additions & 1 deletion src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/attrs/debug.h>
#include "compile_engine.h"

namespace tvm {
Expand Down Expand Up @@ -124,13 +125,48 @@ struct Stack {
};
};

/*! \brief A representation of the interpreter state which can be passed back to Python. */
class InterpreterState;

/*! \brief A container capturing the state of the interpreter. */
class InterpreterStateNode : public Node {
public:
using Frame = tvm::Map<Var, Value>;
using Stack = tvm::Array<Frame>;

/*! \brief The current expression under evaluation. */
Expr current_expr;

/*! \brief The call stack of the interpreter. */
Stack stack;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("current_expr", &current_expr);
v->Visit("stack", &stack);
}

TVM_DLL static InterpreterState make(Expr current_expr, Stack stack);

static constexpr const char* _type_key = "relay.InterpreterState";
TVM_DECLARE_NODE_TYPE_INFO(InterpreterStateNode, Node);
};

RELAY_DEFINE_NODE_REF(InterpreterState, InterpreterStateNode, NodeRef);

InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
NodePtr<InterpreterStateNode> n = make_node<InterpreterStateNode>();
n->current_expr = std::move(current_expr);
n->stack = std::move(stack);
return InterpreterState(n);
}

// NOTE: the current interpreter assumes A-normal form.
// which is better for execution.
//
// It will run duplicated computations when taking program that
// contains DAG in dataflow-form.
// Conversion to ANF is recommended before running the interpretation.
//
// Conversion to ANF is recommended before running the interpretation.
class Interpreter :
public ExprFunctor<Value(const Expr& n)> {
public:
Expand Down Expand Up @@ -209,6 +245,21 @@ class Interpreter :

Value InvokePrimitiveOp(Function func,
const Array<Value>& args) {
auto call_node = func->body.as<CallNode>();

if (call_node && call_node->op == Op::Get("debug")) {
auto dattrs = call_node->attrs.as<DebugAttrs>();
auto interp_state = this->get_state(call_node->args[0]);

if (dattrs->debug_func.defined()) {
dattrs->debug_func(interp_state);
} else {
RELAY_DEBUG(interp_state);
}

return args[0];
}

// Marshal the arguments.
// Handle tuple input/output by flattening them.
size_t arg_len = 0;
Expand Down Expand Up @@ -388,6 +439,16 @@ class Interpreter :
}
}

InterpreterState get_state(Expr e = Expr()) const {
InterpreterStateNode::Stack stack;
for (auto fr : this->stack_.frames) {
InterpreterStateNode::Frame frame = fr.locals;
stack.push_back(frame);
}
auto state = InterpreterStateNode::make(e, stack);
return state;
}

private:
// module
Module mod_;
Expand Down
54 changes: 54 additions & 0 deletions src/relay/op/debug.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*!
* Copyright (c) 2018 by Contributors
* \file nn.cc
* \brief Property def of nn operators.
*/

#include <tvm/relay/op.h>
#include <tvm/relay/attrs/debug.h>
#include <topi/elemwise.h>
#include <vector>
#include "./type_relations.h"
#include "./op_common.h"
#include "./layout.h"

namespace tvm {
namespace relay {

Array<Tensor> DebugCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return Array<Tensor>{ topi::identity(inputs[0]) };
}

RELAY_REGISTER_OP("debug")
.describe(R"code(Enter the interpreter's debugger.

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("program", "Tuple", "The program to execute before debugging.")
.set_support_level(1)
.add_type_rel("Debug", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<FTVMCompute>("FTVMCompute", DebugCompute);

Expr MakeDebug(Expr expr, std::string name) {
auto dattrs = make_node<DebugAttrs>();
if (name.size() > 0) {
dattrs->debug_func = EnvFunc::Get(name);
} else {
dattrs->debug_func = EnvFunc();
}
static const Op& op = Op::Get("debug");
return CallNode::make(op, {expr}, Attrs(dattrs), {});
}

TVM_REGISTER_API("relay.op._make.debug")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeDebug, args, rv);
});

} // namespace relay
} // namespace tvm

32 changes: 32 additions & 0 deletions tests/python/relay/test_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from tvm.relay import var, const, create_executor
from tvm.relay.op import debug


_test_debug_hit = False

def test_debug():
global _test_debug_hit
ex = create_executor()
x = var('x', shape=(), dtype='int32')
_test_debug_hit = False
def did_exec(x):
global _test_debug_hit
_test_debug_hit = True
prog = debug(x, debug_func=did_exec)
result = ex.evaluate(prog, { x: const(1) })
assert _test_debug_hit
assert result.asnumpy() == 1

def test_debug_with_expr():
global _test_debug_hit
_test_debug_hit = False
ex = create_executor()
x = var('x', shape=(), dtype='int32')
_test_debug_hit = False
def did_exec(x):
global _test_debug_hit
_test_debug_hit = True
prog = debug(x + x * x, debug_func=did_exec)
result = ex.evaluate(prog, { x: const(2) })
assert _test_debug_hit
assert result.asnumpy() == 6