diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 5c3ab8b1ffda..f053165470fe 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -133,6 +133,9 @@ This level enables additional math and transform operators. :nosignatures: tvm.relay.image.resize + tvm.relay.vision.multibox_prior + tvm.relay.vision.multibox_transform_loc + tvm.relay.vision.nms **Level 10: Temporary Operators** @@ -160,6 +163,7 @@ Level 1 Definitions .. autofunction:: tvm.relay.mod .. autofunction:: tvm.relay.tanh .. autofunction:: tvm.relay.concatenate +.. autofunction:: tvm.relay.expand_dims .. autofunction:: tvm.relay.nn.softmax .. autofunction:: tvm.relay.nn.log_softmax .. autofunction:: tvm.relay.nn.relu @@ -236,6 +240,9 @@ Level 4 Definitions Level 5 Definitions ------------------- .. autofunction:: tvm.relay.image.resize +.. autofunction:: tvm.relay.vision.multibox_prior +.. autofunction:: tvm.relay.vision.multibox_transform_loc +.. autofunction:: tvm.relay.vision.nms Level 10 Definitions diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 0aeb7f2b1513..1e5265c07959 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -946,9 +946,11 @@ inline TVMType String2TVMType(std::string s) { char* xdelim; // emulate sscanf("%ux%u", bits, lanes) uint8_t bits = static_cast(strtoul(scan, &xdelim, 10)); if (bits != 0) t.bits = bits; + char* endpt = xdelim; if (*xdelim == 'x') { - t.lanes = static_cast(strtoul(xdelim + 1, nullptr, 10)); + t.lanes = static_cast(strtoul(xdelim + 1, &endpt, 10)); } + CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s; return t; } diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 4725c0a7a07d..4e0852f1f7bb 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -49,7 +49,7 @@ def astype(self, dtype): result : tvm.relay.Expr The result expression. """ - return _make.dtype_cast(self, dtype) + return _make.cast(self, dtype) def __add__(self, other): if isinstance(other, Expr): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 2791eaf7d9db..bc0a42d6ab30 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -4,6 +4,26 @@ from ..expr import TupleWrapper +def cast(data, dtype): + """Cast input tensor to data type. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + dtype: str + The target data type + + Returns + ------- + result : relay.Expr + The casted result. + """ + from .. import _make as _relay_make + return _relay_make.cast(data, dtype) + + def expand_dims(data, axis, num_newaxis=1): """Insert `num_newaxis` axises at the position given by `axis`. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index eb8b4f13fb3f..704324533185 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -61,7 +61,7 @@ Expr MakeCast(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay._make.dtype_cast") +TVM_REGISTER_API("relay._make.cast") .set_body([](const TVMArgs& args, TVMRetValue* rv) { runtime::detail::unpack_call(MakeCast, args, rv); }); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 0731ecfef40a..31e87ef04856 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -46,6 +46,11 @@ def test_cast(): assert "dtype=" in yy.astext() assert yy.checked_type == relay.TensorType((8, 9, 4), "int32") + x = relay.var("x", relay.TensorType((8, 9, 4), "float32")) + y = relay.cast(x, "int32") + yy = relay.ir_pass.infer_type(y) + assert "dtype=" in yy.astext() + assert yy.checked_type == relay.TensorType((8, 9, 4), "int32") def test_clip(): a = relay.var("a", relay.TensorType((10, 4), "float32"))