Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ This level enables additional math and transform operators.
:nosignatures:

tvm.relay.image.resize
tvm.relay.vision.multibox_prior
tvm.relay.vision.multibox_transform_loc
tvm.relay.vision.nms


**Level 10: Temporary Operators**
Expand Down Expand Up @@ -160,6 +163,7 @@ Level 1 Definitions
.. autofunction:: tvm.relay.mod
.. autofunction:: tvm.relay.tanh
.. autofunction:: tvm.relay.concatenate
.. autofunction:: tvm.relay.expand_dims
.. autofunction:: tvm.relay.nn.softmax
.. autofunction:: tvm.relay.nn.log_softmax
.. autofunction:: tvm.relay.nn.relu
Expand Down Expand Up @@ -236,6 +240,9 @@ Level 4 Definitions
Level 5 Definitions
-------------------
.. autofunction:: tvm.relay.image.resize
.. autofunction:: tvm.relay.vision.multibox_prior
.. autofunction:: tvm.relay.vision.multibox_transform_loc
.. autofunction:: tvm.relay.vision.nms


Level 10 Definitions
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -946,9 +946,11 @@ inline TVMType String2TVMType(std::string s) {
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
char* endpt = xdelim;
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, nullptr, 10));
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
}
CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t;
}

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def astype(self, dtype):
result : tvm.relay.Expr
The result expression.
"""
return _make.dtype_cast(self, dtype)
return _make.cast(self, dtype)

def __add__(self, other):
if isinstance(other, Expr):
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,26 @@
from ..expr import TupleWrapper


def cast(data, dtype):
"""Cast input tensor to data type.

Parameters
----------
data : relay.Expr
The input data to the operator.

dtype: str
The target data type

Returns
-------
result : relay.Expr
The casted result.
"""
from .. import _make as _relay_make
return _relay_make.cast(data, dtype)


def expand_dims(data, axis, num_newaxis=1):
"""Insert `num_newaxis` axises at the position given by `axis`.

Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Expr MakeCast(Expr data,
return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay._make.dtype_cast")
TVM_REGISTER_API("relay._make.cast")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeCast, args, rv);
});
Expand Down
5 changes: 5 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def test_cast():
assert "dtype=" in yy.astext()
assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")

x = relay.var("x", relay.TensorType((8, 9, 4), "float32"))
y = relay.cast(x, "int32")
yy = relay.ir_pass.infer_type(y)
assert "dtype=" in yy.astext()
assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")

def test_clip():
a = relay.var("a", relay.TensorType((10, 4), "float32"))
Expand Down