Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ This level enables additional math and transform operators.
tvm.relay.reshape_like
tvm.relay.copy
tvm.relay.transpose
tvm.relay.squeeze
tvm.relay.floor
tvm.relay.ceil
tvm.relay.trunc
Expand Down Expand Up @@ -114,7 +115,7 @@ This level enables additional math and transform operators.
tvm.relay.less_equal
tvm.relay.maximum
tvm.relay.minimum
tvm.relay.pow
tvm.relay.power
tvm.relay.where
tvm.relay.argmax
tvm.relay.argmin
Expand Down Expand Up @@ -196,6 +197,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.reshape_like
.. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.squeeze
.. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take
.. autofunction:: tvm.relay.zeros
Expand All @@ -220,7 +222,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.less_equal
.. autofunction:: tvm.relay.maximum
.. autofunction:: tvm.relay.minimum
.. autofunction:: tvm.relay.pow
.. autofunction:: tvm.relay.power
.. autofunction:: tvm.relay.where
.. autofunction:: tvm.relay.argmax
.. autofunction:: tvm.relay.argmin
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
int axis;

TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") {
TVM_ATTR_FIELD(axis).set_default(1)
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("The axis to sum over when computing softmax.");
}
};
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {

/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
Array<IndexExpr> newshape;
Array<Integer> newshape;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape)
.describe("The new shape. Should be compatible with the original shape.");
Expand Down
4 changes: 2 additions & 2 deletions nnvm/src/top/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,9 @@ along which to split the array.
return Array<Tensor>{
topi::split_sections(inputs[0], param.indices_or_sections[0], param.axis) };
} else {
Array<Expr> indices;
Array<Integer> indices;
for (auto i : param.indices_or_sections) {
indices.push_back(tvm::make_const(tvm::Int(32), i));
indices.push_back(static_cast<int>(i));
}
return Array<Tensor>{ topi::split(inputs[0], indices, param.axis) };
}
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from . import expr
from . import module
from . import ir_pass
from .build_module import build, create_executor
from .build_module import build, build_config, create_executor

# Root operators
from .op import Op
Expand All @@ -17,6 +17,7 @@
from . import nn
from . import vision
from . import image
from . import frontend
from . import backend

from .scope_builder import ScopeBuilder
Expand All @@ -40,6 +41,7 @@
scalar_type = ty.scalar_type

# Expr
Expr = expr.Expr
Constant = expr.Constant
Tuple = expr.Tuple
Var = expr.Var
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,18 @@ def lower(self, source_func, target=None):
cached_func: CachedFunc
The result of lowering.
"""
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLower(self, key)
# pylint: disable=broad-except
try:
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLower(self, key)
except Exception:
import traceback
msg = traceback.format_exc()
msg += "Error during compile func\n"
msg += "--------------------------\n"
msg += source_func.astext(show_meta_data=False)
msg += "--------------------------\n"
raise RuntimeError(msg)

def jit(self, source_func, target=None):
"""JIT a source_func to a tvm.Function.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,4 +357,4 @@ def _get_unique_name(self, name):
return name
index = self._name_map[name]
self._name_map[name] += 1
return self.get_unique_name(name + str(index))
return self._get_unique_name(name + str(index))
3 changes: 1 addition & 2 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# List of optimization pass and level when switch on
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"CombineParallelConv2D": 1,
"CombineParallelConv2D": 4,
"OpFusion": 1,
"FoldConstant": 2,
"FoldScaleAxis": 3,
Expand Down Expand Up @@ -157,7 +157,6 @@ def optimize(func, params=None):

if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)

return func


Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/frontend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Relay frontends."""
from __future__ import absolute_import

from .mxnet import from_mxnet
129 changes: 129 additions & 0 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Common utilities"""
from __future__ import absolute_import as _abs


class RequiredAttr(object):
"""Dummpy class to represent required attr"""
pass


class StrAttrsDict(object):
"""Helper class to parse attrs stored as Dict[str, str].

Parameters
----------
attrs : Dict[str, str]
The attributes to be used.
"""
def __init__(self, attrs):
self.attrs = attrs

def get_float(self, key, default=RequiredAttr()):
"""Get float attribute

Parameters
----------
key : str
The attribute key

default : float
The default value.

Returns
-------
value : The result
"""
if key in self.attrs:
return float(self.attrs[key])
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default

def get_int(self, key, default=RequiredAttr()):
"""Get int attribute

Parameters
----------
key : str
The attribute key

default : float
The default value.

Returns
-------
value : The result
"""
if key in self.attrs:
val = self.attrs[key]
if val == "None":
return None
return int(val)
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default

def get_str(self, key, default=RequiredAttr()):
"""Get str attribute

Parameters
----------
key : str
The attribute key

default : float
The default value.

Returns
-------
value : The result
"""
if key in self.attrs:
return self.attrs[key]
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default

def get_int_tuple(self, key, default=RequiredAttr()):
"""Get int tuple attribute

Parameters
----------
key : str
The attribute key

default : float
The default value.

Returns
-------
value : The result
"""
if key in self.attrs:
tshape = self.attrs[key]
return tuple(int(x.strip()) for x in tshape.strip('()').split(','))
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default

def get_bool(self, key, default=RequiredAttr()):
"""Get bool tuple attribute

Parameters
----------
key : str
The attribute key

default : float
The default value.

Returns
-------
value : The result
"""
if key in self.attrs:
val = self.attrs[key]
return val.strip().lower() in ['true', '1', 't', 'y', 'yes']
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
Loading