10 changes: 10 additions & 0 deletions include/tvm/relay/expr.h
@@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const {
return node;
}

/*!
* \brief Print node as text format.
* \param node The node to be printed.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* \return The text representation.
*/
std::string RelayPrint(
const NodeRef& node,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
2 changes: 2 additions & 0 deletions include/tvm/runtime/packed_func.h
@@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> {
using TSelf = TypedPackedFunc<R(Args...)>;
/*! \brief default constructor */
TypedPackedFunc() {}
/*! \brief constructor from null */
TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*)
/*!
* \brief construct by wrap a PackedFunc
*
9 changes: 7 additions & 2 deletions python/tvm/relay/base.py
@@ -22,15 +22,20 @@ def register_relay_node(type_key=None):


class RelayNode(NodeBase):
def astext(self):
"""Base class of all relay node."""
def astext(self, annotate=None):
"""Get the text format of the expression.

Returns
-------
text : str
The text format of the expression.

annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
"""
return _expr._text_print(self)
return _expr.RelayPrint(self, annotate)


@register_relay_node
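To illustrate the new `annotate` hook on `astext`, here is a minimal usage sketch; the expression built with `relay.var`/`relay.add`/`relay.const` is illustrative and not part of this diff. When `annotate` is supplied, its return value is printed in the comment block of each line instead of the inferred type:

import tvm
from tvm import relay

x = relay.var("x", shape=(10,), dtype="float32")
func = relay.Function([x], relay.add(x, relay.const(1.0, "float32")))

# Without annotate, a line falls back to a "# ty=..." comment once types
# have been inferred.
print(func.astext())

# With annotate, the callback runs on each printed expression and its
# string is shown in that line's comment block.
print(func.astext(annotate=lambda expr: "call" if isinstance(expr, relay.Call) else ""))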
4 changes: 3 additions & 1 deletion python/tvm/relay/build_module.py
@@ -173,11 +173,13 @@ def build(func,
else:
tophub_context = autotvm.util.EmptyContext()

cfg = BuildConfig.current

with tophub_context:
func = optimize(func)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func)
func = ir_pass.fuse_ops(func, cfg.opt_level)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
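The effect of reading `BuildConfig.current` is that the configured optimization level now reaches operator fusion. A sketch of the user-facing flow, assuming the existing `relay.build_config` context manager; the expression itself is illustrative:

from tvm import relay

x = relay.var("x", shape=(1, 10), dtype="float32")
func = relay.Function([x], relay.nn.relu(relay.add(x, x)))

# BuildConfig.current is picked up inside build(), so the opt_level chosen
# here is forwarded to ir_pass.fuse_ops along with the other optimizations.
with relay.build_config(opt_level=2):
    graph, lib, params = relay.build(func, target="llvm")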
3 changes: 1 addition & 2 deletions python/tvm/relay/expr.py
@@ -6,7 +6,6 @@
import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
@@ -477,7 +476,7 @@ def astext(self):
text : str
The text format of the tuple expression.
"""
return _expr._text_print(self.tuple_value)
return self.tuple_value.astext()

def __getitem__(self, index):
if index >= len(self):
7 changes: 5 additions & 2 deletions python/tvm/relay/ir_pass.py
@@ -259,17 +259,20 @@ def structural_hash(value):
raise TypeError(msg)


def fuse_ops(expr):
def fuse_ops(expr, opt_level=1):
"""Fuse operators in expr together.

Parameters
----------
expr : tvm.relay.Expr
The input expression.

opt_level : int
The level of fusion optimization.

Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression, containing fused result.
"""
return _ir_pass.FuseOps(expr)
return _ir_pass.FuseOps(expr, opt_level)
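A short sketch of calling the updated pass directly; the expression construction is illustrative only. As in `build_module.py` above, type inference runs before fusion:

from tvm import relay
from tvm.relay import ir_pass

x = relay.var("x", shape=(10,), dtype="float32")
func = relay.Function([x], relay.exp(relay.add(x, x)))

func = ir_pass.infer_type(func)
# The desired optimization level is now passed explicitly to the fusion pass.
fused = ir_pass.fuse_ops(func, opt_level=1)
print(fused.astext())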
58 changes: 57 additions & 1 deletion src/common/arena.h
@@ -38,11 +38,29 @@ class Arena {
/*!
* \brief Allocate a space from Arena for type T
* \param T the data type to be allocated
* \note The space of T is not initialized.
*/
template<typename T>
T* Alloc() {
T* allocate_() {
return static_cast<T*>(Alloc(sizeof(T), alignof(T)));
}
/*!
* \brief Create a new instance of type T.
* \param args The constructor argument.
* \tparam T the type to be created.
* \tparam Args Arguments to the constructor.
*
* \return The allocated object.
* \note The type T must be a simple type, or only contain
* memory allocated from the same arena.
* Otherwise the destructor needs to be called explicitly.
*/
template<typename T, typename... Args>
T* make(Args&&... args) {
T* ptr = allocate_<T>();
new (ptr) T(std::forward<Args>(args)...);
return ptr;
}

private:
// page size 16 KB
@@ -87,6 +105,44 @@
}
};

/*!
* \brief Linked list node
* \tparam T the content data type
*/
template<typename T>
struct LinkNode {
/*! \brief The content value */
T value;
/*! \brief pointer to the next location */
LinkNode<T>* next{nullptr};
};
/*!
* \brief LinkedList structure
* \tparam T the content data type
* \note This is a simple data structure that can be used together with the arena.
* \sa LinkNode
*/
template<typename T>
struct LinkedList {
/*! \brief Head pointer */
LinkNode<T>* head{nullptr};
/*! \brief Tail pointer */
LinkNode<T>* tail{nullptr};
/*!
* \brief Push a new node to the end of the linked list.
* \param node The node to be pushed.
*/
void Push(LinkNode<T>* node) {
node->next = nullptr;
if (this->tail != nullptr) {
this->tail->next = node;
this->tail = node;
} else {
head = tail = node;
}
}
};

} // namespace common
} // namespace tvm
#endif // TVM_COMMON_ARENA_H_
23 changes: 23 additions & 0 deletions src/relay/backend/compile_engine.cc
@@ -109,6 +109,29 @@ class ScheduleGetter :
return {};
}

Array<Tensor> VisitExpr_(const ConstantNode* op) final {
CHECK(op->is_scalar());
void* data = op->data->data;
DataType dtype = TVMType2Type(op->data->dtype);
Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
if (dtype == Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == Int(64)) {
return make_const(dtype, static_cast<const int64_t*>(data)[0]);
} else if (dtype == Float(32)) {
return make_const(dtype, static_cast<const float*>(data)[0]);
} else if (dtype == Float(64)) {
return make_const(dtype, static_cast<const double*>(data)[0]);
} else if (dtype == Bool()) {
return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
} else {
LOG(FATAL) << "not handled";
return tvm::Expr();
}
});
return {value};
}

Array<Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fcompute =
Op::GetAttr<FTVMCompute>("FTVMCompute");
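From the Python side, the new `ConstantNode` case means a fused function that captures a scalar constant can be lowered directly; a hypothetical end-to-end sketch (constant value, shapes, and target are illustrative):

from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
# The scalar 0.5 becomes a 0-d constant; during lowering it is inlined as a
# make_const of the matching dtype by the visitor above.
func = relay.Function([x], relay.add(x, relay.const(0.5, "float32")))

graph, lib, params = relay.build(func, target="llvm")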
26 changes: 19 additions & 7 deletions src/relay/ir/text_printer.cc
@@ -125,6 +125,8 @@ class TextPrinter :
public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*)
public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
public:
explicit TextPrinter(runtime::TypedPackedFunc<std::string(Expr)> annotate)
: annotate_(annotate) {}
/*!
* \brief Print a node to string.
* \param node The node to be printed.
@@ -279,11 +281,11 @@

TextValue VisitExpr_(const CallNode* op) final {
// possibly through meta-data
TextValue call_op = GetValue(op->op);
std::vector<TextValue> args;
for (Expr arg : op->args) {
args.emplace_back(GetValue(arg));
}
TextValue call_op = GetValue(op->op);
TextValue id = this->AllocTempVar();
this->PrintIndent();

@@ -532,7 +534,9 @@
*/
void PrintOptionalInfo(const Expr& expr) {
// additional information in comment.
if (expr->checked_type_.defined()) {
if (annotate_ != nullptr) {
stream_ << " # " << annotate_(expr);
} else if (expr->checked_type_.defined()) {
stream_ << " # ty=";
this->PrintType(expr->checked_type(), stream_);
}
@@ -678,14 +682,19 @@
name = "%" + name;
}
TextValue val(GetUniqueName(name));
CHECK(!memo_.count(var)) << "Duplicated variable " << var;
// still print if ir is malformed, but show the error.
if (memo_.count(var)) {
memo_[var] = TextValue(val.name + "-malformed-ir");
}
memo_[var] = val;
return val;
}

private:
class AttrPrinter;
friend class AttrPrinter;
/*! \brief additional comment function */
runtime::TypedPackedFunc<std::string(Expr)> annotate_;
/*! \brief meta data context */
TextMetaDataContext meta_;
/*! \brief Check whether scope is still valid */
@@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
os << ", " << meta_.GetMetaNode(attrs);
}

std::string RelayPrint(const NodeRef& node) {
return TextPrinter().Print(node);
std::string RelayPrint(const NodeRef& node,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return TextPrinter(annotate).Print(node);
}

TVM_REGISTER_API("relay._expr._text_print")
.set_body_typed<std::string(const NodeRef&)>(RelayPrint);
TVM_REGISTER_API("relay._expr.RelayPrint")
.set_body_typed<std::string(
const NodeRef&,
runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint);

} // namespace relay
} // namespace tvm
18 changes: 1 addition & 17 deletions src/relay/pass/fold_scale_axis.cc
@@ -10,6 +10,7 @@
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include "pattern_util.h"
#include "pass_util.h"
#include "../op/nn/layout.h"

namespace tvm {
@@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc<
//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
public:
std::unordered_map<const Node*, size_t>
Get(const Expr& body) {
this->VisitExpr(body);
return std::move(this->visit_counter_);
}
};
return ExprRefCounter().Get(body);
}

class BackwardPrep : private ExprVisitor {
public: