Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/HalideIR
5 changes: 5 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
int groups;
std::string data_layout;
std::string weight_layout;
std::string out_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") {
Expand Down Expand Up @@ -139,6 +140,10 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
.describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
Expand Down
13 changes: 13 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,19 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
}
};


struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout;
std::string dst_layout;

TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") {
TVM_ATTR_FIELD(src_layout)
.describe("The source layout of the tensor. (e.g. NCHW)");
TVM_ATTR_FIELD(dst_layout)
.describe("The destination layout of the tensor. (e.g. NCHW16c)");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
2 changes: 1 addition & 1 deletion include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ inline const TTypeNode* ExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed";
<< "Type inference for this Expr has not completed. Try to call infer_type pass.";
const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
Expand Down
15 changes: 15 additions & 0 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ using FTVMSchedule = runtime::TypedPackedFunc<
const Array<Tensor>& outs,
const Target& target)>;

/*!
* \brief Alternate the layout of operators or replace the
* operator with other expressions. This function will be invoked
* in AlterOpLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
*/
using FTVMAlterOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<Tensor>& tinfos)>;

/*!
* \brief Forward rewriting rule for a specific op.
*
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <string>

namespace tvm {
Expand Down Expand Up @@ -173,6 +174,21 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
* \param expr The expression.
* \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
* \return The rewritten expression.
*/
Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);


/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from . import schedule
from . import module
from . import node
from . import attrs
from . import ir_builder
from . import target
from . import generic
Expand Down
40 changes: 40 additions & 0 deletions python/tvm/attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
""" TVM Attribute module, which is mainly used for defining attributes of operators"""
from ._ffi.node import NodeBase, register_node as _register_tvm_node
from ._ffi.function import _init_api
from . import _api_internal


@_register_tvm_node
class Attrs(NodeBase):
"""Attribute node, which is mainly use for defining attributes of relay operators.

Used by function registered in python side, such as compute, schedule and alter_layout.
Attrs is passed as the first argument to these functions.
"""
def list_field_info(self):
""" Get fields information

Returns
-------
infos: list of AttrFieldInfo
List of field information
"""
return _api_internal._AttrsListFieldInfo(self)

def keys(self):
"""Get list of names in the attribute.

Returns
-------
keys : list of str
List of keys
"""
fields = self.list_field_info()
for field in fields:
yield field.name

def __getitem__(self, item):
return self.__getattr__(item)


_init_api("tvm.attrs")
14 changes: 14 additions & 0 deletions python/tvm/relay/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@ def register_relay_node(type_key=None):
return _register_tvm_node(type_key)


def register_relay_attr_node(type_key=None):
"""register relay attribute node

Parameters
----------
type_key : str or cls
The type key of the node
"""
if not isinstance(type_key, str):
return _register_tvm_node(
"relay.attrs." + type_key.__name__)(type_key)
return _register_tvm_node(type_key)


class RelayNode(NodeBase):
"""Base class of all relay node."""
def astext(self, show_meta_data=True, annotate=None):
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
}

class BuildConfig(object):
Expand Down Expand Up @@ -157,6 +158,13 @@ def optimize(func, params=None):

if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)

if cfg.pass_enabled("AlterOpLayout"):
func = ir_pass.infer_type(func)
func = ir_pass.canonicalize_ops(func)
func = ir_pass.infer_type(func)
func = ir_pass.alter_op_layout(func)

return func


Expand Down
36 changes: 36 additions & 0 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,23 @@ def simplify_inference(expr):
return _ir_pass.simplify_inference(expr)


def canonicalize_ops(expr):
""" Canonicalize special operators to basic operators.
This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.)

Parameters
----------
e: tvm.relay.Expr
The input Expression

Returns
-------
result: tvm.relay.Expr
An expression without bias_add
"""
return _ir_pass.canonicalize_ops(expr)


def dead_code_elimination(expr):
""" Remove expressions which does not effect the program result (dead code).

Expand Down Expand Up @@ -321,3 +338,22 @@ def combine_parallel_conv2d(expr):
Transformed expression
"""
return _ir_pass.CombineParallelConv2D(expr)


def alter_op_layout(expr):
"""Alternate the layouts of operators or replace primitive operators with
other expressions.
This pass can be used for computing convolution in custom layouts or
other general weight pre-transformation.

Parameters
----------
expr : tvm.relay.Expr
The input expression.

Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression with alternated layout.
"""
return _ir_pass.AlterOpLayout(expr)
4 changes: 3 additions & 1 deletion python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#pylint: disable=wildcard-import, redefined-builtin
"""Relay core operators."""
# operator defs
from .op import get, register, register_schedule, register_compute, Op
from .op import get, register, register_schedule, register_compute, register_alter_op_layout, \
Op

# Operators
from .reduce import *
Expand All @@ -10,6 +11,7 @@
from . import nn
from . import image
from . import vision
from . import op_attrs

# operator registry
from . import _tensor
Expand Down
9 changes: 0 additions & 9 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,3 @@ def clip_compute(attrs, inputs, output_type, target):
return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]

register_schedule("clip", schedule_elemwise)
register_pattern("clip", OpPattern.ELEMWISE)

# concatenate
@register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
return [topi.concatenate(inputs, axis=attrs.axis)]

register_schedule("concatenate", schedule_injective)
register_pattern("concatenate", OpPattern.INJECTIVE)
18 changes: 16 additions & 2 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name
# pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import
import topi
from . import op as _reg
from ._reduce import _schedule_reduce
from .op import schedule_injective, OpPattern

schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective
Expand All @@ -15,10 +17,22 @@
_reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("cast", schedule_broadcast)
_reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective)
_reg.register_schedule("where", schedule_broadcast)

# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)

# concatenate
@_reg.register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need a correct layout function for concatenate?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. For concatenate, all its inputs will be transformed to old layouts. This is the default fallback case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible for concat to propagate layout?

Copy link
Member Author

@merrymercy merrymercy Nov 28, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We should add it

return [topi.concatenate(inputs, axis=attrs.axis)]

_reg.register_schedule("concatenate", schedule_injective)
_reg.register_pattern("concatenate", OpPattern.INJECTIVE)
22 changes: 20 additions & 2 deletions python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def register_schedule(op_name, schedule=None, level=10):
op_name : str
The name of the op.

schedule : function
schedule : function (attrs: Attrs, outs: List[Tensor], target: Target) -> sch: Schedule
The schedule function.

level : int
Expand All @@ -124,7 +124,8 @@ def register_compute(op_name, compute=None, level=10):
op_name : str
The name of the op.

compute : function
compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type, target:Target)
-> List[Tensor]
The compute function.

level : int
Expand All @@ -133,6 +134,23 @@ def register_compute(op_name, compute=None, level=10):
return register(op_name, "FTVMCompute", compute, level)


def register_alter_op_layout(op_name, alter_layout=None, level=10):
"""Register alter op layout function for an op

Parameters
----------
op_name : str
The name of the operator

alter_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
The function for changing the layout or replacing the operator

level : int
The priority level
"""
return register(op_name, "FTVMAlterOpLayout", alter_layout, level)


def register_pattern(op_name, pattern, level=10):
"""Register operator pattern for an op.

Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""The attributes node used for Relay operators"""

from ...attrs import Attrs
from ..base import register_relay_attr_node

@register_relay_attr_node
class Conv2DAttrs(Attrs):
"""Attribute of a Convolution Operator"""
pass

@register_relay_attr_node
class GlobalPool2DAttrs(Attrs):
"""Attribute of a Global 2D Pooling Operator"""
pass
22 changes: 22 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,25 @@ def slice_like(data, shape_like, axes=None):
The computed result.
"""
return _make.slice_like(data, shape_like, axes)


def layout_transform(data, src_layout, dst_layout):
"""Transform the layout of a tensor

Parameters
----------
data : relay.Expr
The source tensor to be transformed

src_layout: str
The source layout. (e.g NCHW)

dst_layout: str
The destination layout. (e.g. NCHW16c)

Returns
-------
ret : relay.Expr
The transformed tensor.
"""
return _make.layout_transform(data, src_layout, dst_layout)
Loading