Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ This level enables fully connected multi-layer perceptron.
tvm.relay.nn.relu
tvm.relay.nn.dropout
tvm.relay.nn.batch_norm
tvm.relay.nn.bias_add



**Level 2: Convolutions**
Expand Down Expand Up @@ -85,8 +87,13 @@ This level enables additional math and transform operators.
tvm.relay.abs
tvm.relay.negative
tvm.relay.take
tvm.relay.zeros
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
tvm.relay.full
tvm.relay.full_like
tvm.relay.cast


**Level 4: Broadcast and Reductions**
Expand Down Expand Up @@ -151,6 +158,9 @@ Level 1 Definitions
.. autofunction:: tvm.relay.nn.softmax
.. autofunction:: tvm.relay.nn.log_softmax
.. autofunction:: tvm.relay.nn.relu
.. autofunction:: tvm.relay.nn.dropout
.. autofunction:: tvm.relay.nn.batch_norm
.. autofunction:: tvm.relay.nn.bias_add


Level 2 Definitions
Expand Down Expand Up @@ -185,6 +195,9 @@ Level 3 Definitions
.. autofunction:: tvm.relay.zeros_like
.. autofunction:: tvm.relay.ones
.. autofunction:: tvm.relay.ones_like
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast


Level 4 Definitions
Expand Down
17 changes: 17 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,23 @@
namespace tvm {
namespace relay {

/*!
* \brief Add a 1D Tensor to an axis of a data.
*
* \note bias_add is a special add operator that is in nn
* and enables automatic derivation of bias's shape.
* You can directly use add for more generalized case.
*/
struct BiasAddAttrs : public tvm::AttrsNode<BiasAddAttrs> {
int axis;

TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") {
TVM_ATTR_FIELD(axis)
.describe("The axis to add the bias")
.set_default(1);
}
};

/*! \brief Attributes used in convolution operators */
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IndexExpr> strides;
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
namespace tvm {
namespace relay {

/*! \brief data type cast */
struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") {
TVM_ATTR_FIELD(dtype)
.describe("Target data type");
}
}; // struct CastAttrs.

/*! \brief Attributes used in expand_dims operators */
struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
int axis;
Expand Down
30 changes: 19 additions & 11 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,17 @@ class ExprFunctor<R(const Expr& n, Args...)> {
}
};

/*! \brief A simple visitor wrapper around ExprFunctor.
/*!
* \brief A simple visitor wrapper around ExprFunctor.
* Recursively visit the content.
*
* Exposes two visitors with default traversal strategies, one
* which doesn't compute a result but can mutate internal state,
* and another which functionally builds a new Expr.
* ExprVisitor treats Expr as dataflow graph,
* and only visit each Expr node once.
*/

class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
class ExprVisitor
: public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
public:
void VisitExpr(const Expr& expr) override;
void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const GlobalVarNode* op) override;
void VisitExpr_(const ConstantNode* op) override;
Expand All @@ -132,13 +134,19 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
virtual void VisitType(const Type& t);

private:
// internal visited flag.
std::unordered_set<const Node*> visited_;
};

/*! \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
/*!
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once.
* The mutated results are memoized in a map and reused so that
* local transformation on the dataflow preserves the graph structure.
*/
class ExprMutator
: public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
public:
Expand Down
29 changes: 10 additions & 19 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,35 +102,26 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*/
bool WellFormed(const Expr& e);

/*! \brief Get free variables from expression e.
/*! \brief Get free Vars from expr in PostDFS order.
*
* Free variables are variables that are not bound by a let or a function parameter in the context.
* Free variables are variables that are not bound by a
* let or a function parameter in the context.
*
* \param e the expression.
* \param expr the expression.
*
* \return the set of free variable.
* \return List of free vars, in the PostDFS order visited by expr.
*/
tvm::Array<Var> FreeVariables(const Expr& e);
tvm::Array<Var> FreeVars(const Expr& expr);

/*! \brief Get free type parameters from expression e.
/*! \brief Get free TypeVars from expression expr.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
*
* \param e the expression.
* \param expr the expression.
*
* \return the set of free type variables.
* \return List of free vars, in the PostDFS order visited by expr.
*/
tvm::Array<TypeVar> FreeTypeVariables(const Expr& e);

/*! \brief Get free type parameters from type t.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
*
* \param t the type.
*
* \return the set of free type variables.
*/
tvm::Array<TypeVar> FreeTypeVariables(const Type& t);
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);

/*! \brief Remove expressions which does not effect the program result.
*
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ def __init__(self, dtype, value):
self.__init_handle_by_constructor__(
_make.IntImm, dtype, value)

def __int__(self):
return self.value


@register_node
class UIntImm(ConstExpr):
Expand Down
32 changes: 28 additions & 4 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .base import RelayNode, register_relay_node
from . import _make
from . import ty as _ty
from .._ffi import base as _base, node as _node
from .._ffi import base as _base
from .. import nd as _nd
from .. import convert

Expand All @@ -28,6 +28,25 @@ def checked_type(self):
" the checked_type for this node")
return ret

def astype(self, dtype):
"""Cast the content type of the current data to dtype.

Parameters
----------
dtype : str
The target data type.

Note
----
This function only works for TensorType Exprs.

Returns
-------
result : tvm.relay.Expr
The result expression.
"""
return _make.dtype_cast(self, dtype)


@register_relay_node
class Constant(Expr):
Expand Down Expand Up @@ -62,6 +81,9 @@ def __getitem__(self, index):
def __len__(self):
return len(self.fields)

def astype(self, _):
raise TypeError("astype cannot be used on tuple")


@register_relay_node
class Var(Expr):
Expand Down Expand Up @@ -238,7 +260,7 @@ def __init__(self, tuple_value, index):
_make.TupleGetItem, tuple_value, index)


class TupleWrapper(_node.NodeGeneric):
class TupleWrapper(object):
"""TupleWrapper.

This class is a Python wrapper for a Relay tuple of known size.
Expand All @@ -257,10 +279,9 @@ def __init__(self, tuple_value, size):
self.tuple_value = tuple_value
self.size = size

def asnode(self):
def astuple(self):
"""Returns the underlying Relay tuple if this wrapper is passed
as an argument to an FFI function."""

return self.tuple_value

def __getitem__(self, index):
Expand All @@ -275,6 +296,9 @@ def __repr__(self):
return ("TupleWrapper(" + self.tuple_value.__repr__() +
", " + self.size + ")")

def astype(self, _):
raise TypeError("astype cannot be used on tuple")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tuple -> tuplewrapper?



def var(name_hint,
type_annotation=None,
Expand Down
40 changes: 23 additions & 17 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ def infer_type(expr, env=None):
Parameters
----------
expr: tvm.relay.Expr
The input expression.
The input expression.

env: Optional[tvm.relay.Environment]
The global environment.
The global environment.


Returns
-------
checked_expr : tvm.relay.Expr
The checked expression.
The checked expression.
"""
return _ir_pass.infer_type(expr, env)

Expand All @@ -35,12 +35,12 @@ def well_formed(expr):
Parameters
----------
expr: tvm.relay.Expr
The input expression
The input expression

Returns
-------
well_form : bool
whether the input expression is well formed
Whether the input expression is well formed
"""
return _ir_pass.well_formed(expr)

Expand All @@ -52,15 +52,15 @@ def check_kind(t, env=None):
Parameters
----------
t: tvm.relay.Type
The type to check
The type to check

env: tvm.relay.Environment, optional
The global environment
The global environment

Returns
-------
well_kinded : bool
whether the input type is well kinded.
whether the input type is well kinded.

Examples
--------
Expand All @@ -75,20 +75,26 @@ def check_kind(t, env=None):
return _ir_pass.check_kind(t)


def free_vars(e):
"""Get free variables from expression e.
def free_vars(expr):
"""Get free Vars from expression expr in Post DFS order.

Parameters
----------
e: tvm.relay.Expr
The input expression
expr: tvm.relay.Expr
The input expression

Returns
-------
free : List[tvm.relay.Var]
The list of free variables
The list of free variables in post DFS order.

Note
----
The fact that Vars are post-DFS ordred are useful in
neural networks: usually this means weights of previous
are ordered first.
"""
return _ir_pass.free_vars(e)
return _ir_pass.free_vars(expr)


def free_type_vars(expr):
Expand Down Expand Up @@ -130,15 +136,15 @@ def alpha_equal(lhs, rhs):
Parameters
----------
lhs: tvm.relay.Expr
One of the input Expression.
One of the input Expression.

rhs: tvm.relay.Expr
One of the input Expression.
One of the input Expression.

Returns
-------
result: bool
True iff lhs is alpha equal to rhs.
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))

Expand Down
Loading