From 3b4521b10e8311b569b0e5afb3e46fde7dfde612 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 24 Jun 2021 14:23:37 -0600 Subject: [PATCH 01/12] rename resize to resize2d --- include/tvm/relay/attrs/image.h | 8 ++--- python/tvm/relay/frontend/mxnet.py | 2 +- python/tvm/relay/frontend/onnx.py | 20 ++++++++--- python/tvm/relay/frontend/pytorch.py | 4 +-- python/tvm/relay/frontend/tflite.py | 2 +- python/tvm/relay/op/dyn/image/_image.py | 18 +++++----- python/tvm/relay/op/image/_image.py | 30 ++++++++-------- python/tvm/relay/op/image/image.py | 8 ++--- python/tvm/relay/op/op_attrs.py | 16 ++++----- python/tvm/topi/image/resize.py | 2 +- python/tvm/topi/nn/upsampling.py | 2 +- src/relay/op/dyn/image/resize.cc | 26 +++++++------- src/relay/op/image/resize.cc | 44 +++++++++++------------ src/relay/op/make_op.h | 6 ++-- src/relay/transforms/dynamic_to_static.cc | 10 +++--- 15 files changed, 104 insertions(+), 94 deletions(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index baceb04958f0..28ef3fb03aa3 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -33,7 +33,7 @@ namespace tvm { namespace relay { /*! \brief Attributes used in image resize operator */ -struct ResizeAttrs : public tvm::AttrsNode { +struct Resize2DAttrs : public tvm::AttrsNode { Array size; std::string layout; std::string method; @@ -43,7 +43,7 @@ struct ResizeAttrs : public tvm::AttrsNode { int bicubic_exclude; DataType out_dtype; - TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { + TVM_DECLARE_ATTRS(Resize2DAttrs, "relay.attrs.Resize2DAttrs") { TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); TVM_ATTR_FIELD(layout).set_default("NCHW").describe( "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." @@ -80,14 +80,14 @@ struct ResizeAttrs : public tvm::AttrsNode { }; /*! \brief Attributes used in image resize3d operator */ -struct Resize3dAttrs : public tvm::AttrsNode { +struct Resize3DAttrs : public tvm::AttrsNode { Array size; String layout; String method; String coordinate_transformation_mode; DataType out_dtype; - TVM_DECLARE_ATTRS(Resize3dAttrs, "relay.attrs.Resize3dAttrs") { + TVM_DECLARE_ATTRS(Resize3DAttrs, "relay.attrs.Resize3DAttrs") { TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 3b940bd15f5b..59b4e99de999 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -963,7 +963,7 @@ def _mx_resize(inputs, attrs): if scale_width is not None: width = (scale_width * shape[3]).astype("int32") size = (height, width) - return _op.image.resize(inputs[0], size, coordinate_transformation_mode="align_corners") + return _op.image.resize2d(inputs[0], size, coordinate_transformation_mode="align_corners") def _mx_amp_multicast(inputs, attrs): diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7135fccdf43b..fd88cddb94ac 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -457,6 +457,7 @@ def _impl_v1(cls, inputs, attr, params): kernel_type = infer_type(inputs[1]) kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] + print(input_shape, kernel_shapes) if "kernel_shape" not in attr: attr["kernel_shape"] = kernel_shapes[0][2:] @@ -1364,7 +1365,13 @@ def _impl_v10(cls, inputs, attr, params): ends = inputs[2] axes = inputs[3] steps = inputs[4] - + print("----------Slice------------") + print(inputs[0]) + print(inputs[1]) + print(inputs[2]) + print(inputs[3]) + print(inputs[4]) + print("----------/Slice------------") ishape = infer_shape(inputs[0]) data_rank = len(ishape) @@ -2398,13 +2405,16 @@ def _impl_v10(cls, inputs, attr, params): scale = inputs[1] size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale - layout = "NCHW" # ONNX assumes NCHW layout + ndims = len(infer_shape(inputs[0])) + layout = {3: "NCW", 4: "NCHW", 5: "NCDHW"}[ndims] out_size = fold_constant(_op.strided_slice(size, [2], [4])) - return _op.image.resize(inputs[0], out_size, layout, method, "asymmetric") + return _op.image.resize2d(inputs[0], out_size, layout, method, "asymmetric") @classmethod def _impl_v11(cls, inputs, attr, params): - layout = "NCHW" # ONNX assumes NCHW layout + print({**attr}) + ndims = len(infer_shape(inputs[0])) + layout = {3: "NCH", 4: "NCHW", 5: "NCDHW"}[ndims] mode = attr.get("mode").decode("ascii") if mode == "nearest": @@ -2435,7 +2445,7 @@ def _impl_v11(cls, inputs, attr, params): size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale out_size = fold_constant(_op.strided_slice(size, [2], [4])) - return _op.image.resize( + return _op.image.resize2d( inputs[0], out_size, layout, method, coord_trans, nearest_mode, alpha, exclude ) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 00fa9f597d06..5c252739f190 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1826,7 +1826,7 @@ def upsample(inputs, input_types): coord_trans = "half_pixel" def func(x): - return _op.image.resize(x, out_size, "NCHW", method, coord_trans) + return _op.image.resize2d(x, out_size, "NCHW", method, coord_trans) if self.is_quantized_tensor(data): # input qparams are manually appended by us @@ -2203,7 +2203,7 @@ def interpolate(self, inputs, input_types): else: coord_trans = "half_pixel" - return _op.image.resize(data, out_size, "NCHW", method, coord_trans) + return _op.image.resize2d(data, out_size, "NCHW", method, coord_trans) def numel(self, inputs, input_types): return _op.ndarray_size(inputs[0]) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a47fdf0141b5..0dee44cd86ec 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -647,7 +647,7 @@ def _convert_resize(self, method, op): coord_trans = "align_corners" if align_corners else "asymmetric" if bilinear_method and input_tensor.qnn_params: in_expr = self.dequantize(in_expr, input_tensor) - out = _op.image.resize( + out = _op.image.resize2d( in_expr, target_size, "NHWC", method, coordinate_transformation_mode=coord_trans ) if bilinear_method and output_tensor.qnn_params: diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index 32bd88456ffc..7f66a69b4803 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -26,8 +26,8 @@ # resize -@reg.register_compute("dyn.image.resize") -def compute_resize(attrs, inputs, out_type): +@reg.register_compute("dyn.image.resize2d") +def compute_resize2d(attrs, inputs, out_type): layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode @@ -36,7 +36,7 @@ def compute_resize(attrs, inputs, out_type): bicubic_exclude = attrs.bicubic_exclude out_dtype = attrs.out_dtype return [ - tvm.topi.image.resize( + tvm.topi.image.resize2d( inputs[0], inputs[1], layout, @@ -51,11 +51,11 @@ def compute_resize(attrs, inputs, out_type): ] -reg.register_injective_schedule("dyn.image.resize") +reg.register_injective_schedule("dyn.image.resize2d") @script -def _resize_shape_func(dshape, size, ndim, height_axis, width_axis): +def _resize2d_shape_func(dshape, size, ndim, height_axis, width_axis): out = output_tensor((ndim,), "int64") for i in const_range(ndim): out[i] = int64(dshape[i]) @@ -64,15 +64,15 @@ def _resize_shape_func(dshape, size, ndim, height_axis, width_axis): return out -@reg.register_shape_func("dyn.image.resize", True) -def resize_shape_func(attrs, inputs, _): +@reg.register_shape_func("dyn.image.resize2d", True) +def resize2d_shape_func(attrs, inputs, _): """ Shape function for dyn.image.resize op. """ layout = attrs.layout if nchw_pack_layout(layout) or nchw_xc_layout(layout): out = [ - _resize_shape_func( + _resize2d_shape_func( inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), convert(2), convert(3) ) ] @@ -84,7 +84,7 @@ def resize_shape_func(attrs, inputs, _): if letter == "W": width_axis = i out = [ - _resize_shape_func( + _resize2d_shape_func( inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index 2071a43f828b..c0f8bb5011b8 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -26,13 +26,13 @@ from .. import op as reg from .. import strategy from ..op import OpPattern -from .image import resize +from .image import resize2d # resize -@reg.register_compute("image.resize") -def compute_resize(attrs, inputs, out_type): - """compute definition for resize op""" +@reg.register_compute("image.resize2d") +def compute_resize2d(attrs, inputs, out_type): + """compute definition for resize2d op""" size = attrs.size layout = attrs.layout method = attrs.method @@ -42,7 +42,7 @@ def compute_resize(attrs, inputs, out_type): bicubic_exclude = attrs.bicubic_exclude out_dtype = attrs.out_dtype return [ - topi.image.resize( + topi.image.resize2d( inputs[0], size, layout, @@ -56,12 +56,12 @@ def compute_resize(attrs, inputs, out_type): ] -reg.register_injective_schedule("image.resize") +reg.register_injective_schedule("image.resize2d") -@reg.register_convert_op_layout("image.resize") -def convert_image_resize(attrs, inputs, tinfos, desired_layouts): - """Convert Layout pass registration for image resize op. +@reg.register_convert_op_layout("image.resize2d") +def convert_image_resize2d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for image resize2d op. Parameters ---------- @@ -86,11 +86,11 @@ def convert_image_resize(attrs, inputs, tinfos, desired_layouts): desired_layout = str(desired_layouts[0]) assert desired_layout != "default", "Layout cannot be default" new_attrs["layout"] = desired_layout - return resize(*inputs, **new_attrs) + return resize2d(*inputs, **new_attrs) @script -def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis): +def _resize2d_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis): out = output_tensor((4,), "int64") out[batch_axis] = int64(image_shape[0]) out[height_axis] = int64(size[0]) @@ -99,10 +99,10 @@ def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, c return out -@reg.register_shape_func("image.resize", False) -def resize_shape_func(attrs, inputs, _): +@reg.register_shape_func("image.resize2d", False) +def resize2d_shape_func(attrs, inputs, _): """ - Shape function for resize op. + Shape function for resize2d op. """ layout = attrs.layout height_axis = width_axis = channel_axis = 1 @@ -117,7 +117,7 @@ def resize_shape_func(attrs, inputs, _): channel_axis = i size = get_const_tuple(attrs.size) return [ - _resize_shape_func( + _resize2d_shape_func( inputs[0], convert(size), convert(batch_axis), diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 6d7d79264844..78507c795de8 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -20,7 +20,7 @@ from ...expr import Expr, Constant -def resize( +def resize2d( data, size, layout="NCHW", @@ -31,7 +31,7 @@ def resize( bicubic_exclude=0, out_dtype=None, ): - """Image resize operator. + """Image resize2d operator. This operator takes data as input and does 2D scaling to the given scale factor. In the default case, where the data_layout is `NCHW` @@ -82,7 +82,7 @@ def resize( if isinstance(size, Constant): size = list(size.data.numpy().astype("int32")) if isinstance(size, Expr): - return _dyn_make.resize( + return _dyn_make.resize2d( data, size, layout, @@ -93,7 +93,7 @@ def resize( bicubic_exclude, out_dtype, ) - return _make.resize( + return _make.resize2d( data, size, layout, diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 780badc89fc4..a7a6efc24b71 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -139,9 +139,14 @@ class DeformableConv2DAttrs(Attrs): """Attributes for nn.deformable_conv2d""" -@tvm._ffi.register_object("relay.attrs.ResizeAttrs") -class ResizeAttrs(Attrs): - """Attributes for image.resize""" +@tvm._ffi.register_object("relay.attrs.Resize2DAttrs") +class Resize2DAttrs(Attrs): + """Attributes for image.resize2d""" + + +@tvm._ffi.register_object("relay.attrs.Resize3DAttrs") +class Resize3DAttrs(Attrs): + """Attributes used in resize3d operators""" @tvm._ffi.register_object("relay.attrs.CropAndResizeAttrs") @@ -499,11 +504,6 @@ class RequantizeAttrs(Attrs): """Attributes used in requantize operators""" -@tvm._ffi.register_object("relay.attrs.Resize3dAttrs") -class Resize3dAttrs(Attrs): - """Attributes used in resize3d operators""" - - @tvm._ffi.register_object("relay.attrs.ScatterAttrs") class ScatterAttrs(Attrs): """Attributes used in scatter operators""" diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 42d0455665a1..e49e929ac52d 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -615,7 +615,7 @@ def _cubic_kernel(inputs, w): return _cast_output(value, data.dtype, out_dtype=out_dtype) -def resize( +def resize2d( data, size, layout="NCHW", diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py index b95835f6e103..0ebe96d3acdd 100644 --- a/python/tvm/topi/nn/upsampling.py +++ b/python/tvm/topi/nn/upsampling.py @@ -92,7 +92,7 @@ def upsampling( else: raise ValueError("not support this layout {} yet".format(layout)) coord_trans = "align_corners" if align_corners else "asymmetric" - return topi.image.resize( + return topi.image.resize2d( data, reshape_size, layout=layout, diff --git a/src/relay/op/dyn/image/resize.cc b/src/relay/op/dyn/image/resize.cc index 87cf89a223ec..c438b1a04790 100644 --- a/src/relay/op/dyn/image/resize.cc +++ b/src/relay/op/dyn/image/resize.cc @@ -31,10 +31,10 @@ namespace tvm { namespace relay { namespace dyn { -TVM_REGISTER_NODE_TYPE(ResizeAttrs); +TVM_REGISTER_NODE_TYPE(Resize2DAttrs); -bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +bool Resize2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { // {data, size, out} ICHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -42,7 +42,7 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, static const Layout kNCHW("NCHW"); - const ResizeAttrs* param = attrs.as(); + const Resize2DAttrs* param = attrs.as(); ICHECK(param != nullptr); const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); @@ -66,10 +66,10 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. -Expr MakeResize(Expr data, Expr size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, - double bicubic_exclude, DataType out_dtype) { - auto attrs = make_object(); +Expr MakeResize2D(Expr data, Expr size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, + double bicubic_alpha, double bicubic_exclude, DataType out_dtype) { + auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; @@ -77,13 +77,13 @@ Expr MakeResize(Expr data, Expr size, String layout, String method, attrs->bicubic_alpha = bicubic_alpha; attrs->bicubic_exclude = bicubic_exclude; attrs->out_dtype = out_dtype; - static const Op& op = Op::Get("dyn.image.resize"); + static const Op& op = Op::Get("dyn.image.resize2d"); return Call(op, {data, size}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.dyn.image._make.resize").set_body_typed(MakeResize); +TVM_REGISTER_GLOBAL("relay.op.dyn.image._make.resize2d").set_body_typed(MakeResize2D); -RELAY_REGISTER_OP("dyn.image.resize") +RELAY_REGISTER_OP("dyn.image.resize2d") .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. - **data**: data is 4D array of shape @@ -100,12 +100,12 @@ RELAY_REGISTER_OP("dyn.image.resize") for layout NHWC (batch_size, size[0], size[1], channels) )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("size", "Tensor", "The output size tensor.") .set_support_level(5) - .add_type_rel("DynResize", ResizeRel) + .add_type_rel("DynResize2D", Resize2DRel) .set_attr("TOpPattern", kInjective); } // namespace dyn diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index b672c7f87c05..e7626562b6a6 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -31,7 +31,7 @@ namespace tvm { namespace relay { -TVM_REGISTER_NODE_TYPE(ResizeAttrs); +TVM_REGISTER_NODE_TYPE(Resize2DAttrs); template InferCorrectLayoutOutput ResizeInferCorrectLayout(const Attrs& attrs, @@ -58,15 +58,15 @@ InferCorrectLayoutOutput ResizeInferCorrectLayout(const Attrs& attrs, return InferCorrectLayoutOutput({params->layout}, {params->layout}, Attrs(params)); } -bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +bool Resize2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; static const Layout kNCHW("NCHW"); - const ResizeAttrs* param = attrs.as(); + const Resize2DAttrs* param = attrs.as(); ICHECK(param != nullptr); const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); @@ -90,10 +90,10 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. -Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, - int bicubic_exclude, DataType out_dtype) { - auto attrs = make_object(); +Expr MakeResize2D(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, + double bicubic_alpha, int bicubic_exclude, DataType out_dtype) { + auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); @@ -102,13 +102,13 @@ Expr MakeResize(Expr data, Array size, String layout, String method, attrs->bicubic_alpha = bicubic_alpha; attrs->bicubic_exclude = bicubic_exclude; attrs->out_dtype = out_dtype; - static const Op& op = Op::Get("image.resize"); + static const Op& op = Op::Get("image.resize2d"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.image._make.resize").set_body_typed(MakeResize); +TVM_REGISTER_GLOBAL("relay.op.image._make.resize2d").set_body_typed(MakeResize2D); -RELAY_REGISTER_OP("image.resize") +RELAY_REGISTER_OP("image.resize2d") .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. - **data**: data is 4D array of shape @@ -122,17 +122,17 @@ RELAY_REGISTER_OP("image.resize") for layout NHWC (batch_size, size[0], size[1], channels) )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) - .add_type_rel("Resize", ResizeRel) - .set_attr("FInferCorrectLayout", ResizeInferCorrectLayout) + .add_type_rel("Resize2D", Resize2DRel) + .set_attr("FInferCorrectLayout", ResizeInferCorrectLayout) .set_attr("TOpPattern", kInjective); -TVM_REGISTER_NODE_TYPE(Resize3dAttrs); +TVM_REGISTER_NODE_TYPE(Resize3DAttrs); -bool Resize3dRel(const Array& types, int num_inputs, const Attrs& attrs, +bool Resize3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -140,7 +140,7 @@ bool Resize3dRel(const Array& types, int num_inputs, const Attrs& attrs, static const Layout kNCDHW("NCDHW"); - const Resize3dAttrs* param = attrs.as(); + const Resize3DAttrs* param = attrs.as(); ICHECK(param != nullptr); const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW); @@ -165,9 +165,9 @@ bool Resize3dRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. -Expr MakeResize3d(Expr data, Array size, String layout, String method, +Expr MakeResize3D(Expr data, Array size, String layout, String method, String coordinate_transformation_mode, DataType out_dtype) { - auto attrs = make_object(); + auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); @@ -177,7 +177,7 @@ Expr MakeResize3d(Expr data, Array size, String layout, String method return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.image._make.resize3d").set_body_typed(MakeResize3d); +TVM_REGISTER_GLOBAL("relay.op.image._make.resize3d").set_body_typed(MakeResize3D); RELAY_REGISTER_OP("image.resize3d") .describe(R"code( @@ -194,11 +194,11 @@ Perform resize3d to input array with nearest neighbour or bilinear interpolation for layout NDHWC (batch_size, size[0], size[1], size[2], channels) )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) - .add_type_rel("Resize3d", Resize3dRel) + .add_type_rel("Resize3d", Resize3DRel) .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(CropAndResizeAttrs); diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 6f4db5ab268a..927fb1bdbbe9 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -101,9 +101,9 @@ Expr MakeZeros(Array shape, DataType dtype); Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype); -Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, - int bicubic_exclude, DataType out_dtype); +Expr MakeResize2D(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, + double bicubic_alpha, int bicubic_exclude, DataType out_dtype); Expr MakeSparseToDense(Expr indices, Array output_shape, Expr values, Expr default_value); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 7c947ba109bf..de3c61f2d573 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -106,20 +106,20 @@ class DynamicToStaticMutator : public MixedModeMutator { } return Expr(nullptr); }}, - {Op::Get("dyn.image.resize"), + {Op::Get("dyn.image.resize2d"), [this](const CallNode* call_node) { auto args = PrepareArgs(call_node); if (const ConstantNode* size = args[1].as()) { - const ResizeAttrs* param = call_node->attrs.as(); + const Resize2DAttrs* param = call_node->attrs.as(); ICHECK(param); auto size_int = ToVector(size->data); Array size_prim; for (size_t i = 0; i < size_int.size(); ++i) { size_prim.push_back(size_int[i]); } - return MakeResize(call_node->args[0], size_prim, param->layout, param->method, - param->coordinate_transformation_mode, param->rounding_method, - param->bicubic_alpha, param->bicubic_exclude, param->out_dtype); + return MakeResize2D(call_node->args[0], size_prim, param->layout, param->method, + param->coordinate_transformation_mode, param->rounding_method, + param->bicubic_alpha, param->bicubic_exclude, param->out_dtype); } return Expr(nullptr); }}, From 02c251791d338a9416f643e390271e8c234b4dc9 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 25 Jun 2021 12:55:30 -0600 Subject: [PATCH 02/12] refactor resize_2d --- python/tvm/relay/op/image/_image.py | 94 ++++ python/tvm/relay/op/image/image.py | 76 ++++ python/tvm/topi/image/resize.py | 645 ++++++++-------------------- python/tvm/topi/utils.py | 10 + 4 files changed, 363 insertions(+), 462 deletions(-) diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index c0f8bb5011b8..e7c7996b7a1c 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -30,6 +30,100 @@ # resize +@reg.register_compute("image.resize1d") +def compute_resize1d(attrs, inputs, out_type): + """compute definition for resize1d op""" + size = attrs.size + layout = attrs.layout + method = attrs.method + coord_trans = attrs.coordinate_transformation_mode + rounding_method = attrs.rounding_method + bicubic_alpha = attrs.bicubic_alpha + bicubic_exclude = attrs.bicubic_exclude + out_dtype = attrs.out_dtype + return [ + topi.image.resize1d( + inputs[0], + size, + layout, + method, + coord_trans, + rounding_method, + bicubic_alpha, + bicubic_exclude, + out_dtype, + ) + ] + + +reg.register_injective_schedule("image.resize1d") + + +@reg.register_convert_op_layout("image.resize1d") +def convert_image_resize1d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for image resize1d op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current resize op + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data input. + + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + + new_attrs = dict(attrs) + assert len(desired_layouts) == 1, "Only one desired layout is expected" + desired_layout = str(desired_layouts[0]) + assert desired_layout != "default", "Layout cannot be default" + new_attrs["layout"] = desired_layout + return resize1d(*inputs, **new_attrs) + + +@script +def _resize1d_shape_func(image_shape, size, batch_axis, width_axis, channel_axis): + out = output_tensor((3,), "int64") + out[batch_axis] = int64(image_shape[0]) + out[width_axis] = int64(size[1]) + out[channel_axis] = image_shape[channel_axis] + return out + + +@reg.register_shape_func("image.resize1d", False) +def resize1d_shape_func(attrs, inputs, _): + """ + Shape function for resize2d op. + """ + layout = attrs.layout + width_axis = channel_axis = 1 + for i, letter in enumerate(layout): + if letter == "N": + batch_axis = i + if letter == "W": + width_axis = i + if letter == "C": + channel_axis = i + size = get_const_tuple(attrs.size) + return [ + _resize1d_shape_func( + inputs[0], + convert(size), + convert(batch_axis), + convert(width_axis), + convert(channel_axis), + ) + ] + + @reg.register_compute("image.resize2d") def compute_resize2d(attrs, inputs, out_type): """compute definition for resize2d op""" diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 78507c795de8..ab6b0df9850b 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -20,6 +20,82 @@ from ...expr import Expr, Constant +def resize1d( + data, + size, + layout="NCW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + bicubic_alpha=-0.5, + bicubic_exclude=0, + out_dtype=None, +): + """Image resize1d operator. + + This operator takes data as input and does 1D scaling to the given scale factor. + In the default case, where the data_layout is `NCW` + with data of shape (n, c, w) + out will have a shape (n, c, size[0]) + + method indicates the algorithm to be used while calculating the out value + and method can be one of ("linear", "nearest_neighbor", "cubic") + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + size: Tuple of Int or Expr + The out size to which the image will be resized. + + layout : str, optional + Layout of the input. + + method : str, optional + Scale method to used [nearest_neighbor, linear, cubic]. + + coordinate_transformation_mode : string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + [half_pixel, align_corners, asymmetric] + + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + bicubic_alpha: float + Spline Coefficient for Bicubic Interpolation + + bicubic_exclude: int + Flag to exclude exterior of the image during bicubic interpolation + + out_dtype : str, optional + Type to return. If left None returns the same type as input. + + Returns + ------- + result: relay.Expr + The resized result. + """ + if isinstance(size, Constant): + size = list(size.data.numpy().astype("int32")) + if isinstance(size, Expr): + raise NotImplementedError("dyn.resize1d is not yet implemented, got size", size) + return _make.resize1d( + data, + size, + layout, + method, + coordinate_transformation_mode, + rounding_method, + bicubic_alpha, + bicubic_exclude, + out_dtype, + ) + + def resize2d( data, size, diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index e49e929ac52d..87e3ff2ee54f 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -23,6 +23,25 @@ from .. import tag +def get_1d_indices(indices, layout="NCW"): + """Get 1d indices""" + (cc, inum, ic) = (0, 0, 0) + if layout == "NWC": + n, x, c = indices + cc = None + elif layout == "NCW": + n, c, x = indices + cc = None + elif ncw_pack_layout(layout): + n, c, x, inum, ic = indices + else: + # else must be NCHWxc + assert ncw_xc_layout(layout) + n, c, x, cc = indices + + return n, c, x, cc, inum, ic + + def get_2d_indices(indices, layout="NCHW"): """Get 2d indices""" (cc, inum, ic) = (0, 0, 0) @@ -42,6 +61,22 @@ def get_2d_indices(indices, layout="NCHW"): return n, c, y, x, cc, inum, ic +def get_1d_pixel(data, layout, boxes, image_width, n, c, x, cc, ib, ic): + """Get 1d pixel""" + if boxes is None: + x = tvm.te.max(tvm.te.min(x, image_width - 1), 0) + if layout == "NWC": + return data(n, x, c).astype("float") + if layout == "NCW": + return data(n, c, x).astype("float") + if ncw_pack_layout(layout): + return data(n, c, x, ib, ic).astype("float") + + # else must be NCHWxc + assert ncw_xc_layout(layout) + return data(n, c, x, cc).astype("float") + + def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, ib, ic): """Get 2d pixel""" if boxes is None: @@ -59,198 +94,76 @@ def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, return data(n, c, y, x, cc).astype("float") -def get_iny_inx( - y, x, image_height, image_width, target_height, target_width, coordinate_transformation_mode -): - """Infer input x,y from output x,y with various coordinate transformation methods""" - scale_y = te.div(image_height.astype("float"), target_height.astype("float")) +def get_inx(x, image_width, target_width, coordinate_transformation_mode): + """Infer input x from output x with various coordinate transformation methods""" scale_x = te.div(image_width.astype("float"), target_width.astype("float")) if coordinate_transformation_mode == "half_pixel": - in_y = (y + 0.5) * scale_y - 0.5 in_x = (x + 0.5) * scale_x - 0.5 elif coordinate_transformation_mode == "align_corners": - in_y = (image_height - 1).astype("float") / (target_height - 1) * y in_x = (image_width - 1).astype("float") / (target_width - 1) * x elif coordinate_transformation_mode == "asymmetric": - in_y = scale_y * y in_x = scale_x * x elif coordinate_transformation_mode == "pytorch_half_pixel": - in_y = te.if_then_else(target_height > 1, (y + 0.5) * scale_y - 0.5, 0.0) in_x = te.if_then_else(target_width > 1, (x + 0.5) * scale_x - 0.5, 0.0) elif coordinate_transformation_mode == "tf_half_pixel_for_nn": - in_y = (y + 0.5) * scale_y in_x = (x + 0.5) * scale_x else: raise ValueError( "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) ) - return in_y, in_x + return in_x -def resize_nearest_neighbor( - indices, - data, - image_height, - image_width, - target_height, - target_width, - boxes=None, - box_indices=None, - extrapolation_value=None, - layout="NCHW", - coordinate_transformation_mode="align_corners", - rounding_method="", - out_dtype=None, +def get_iny_inx( + y, x, image_height, image_width, target_height, target_width, coordinate_transformation_mode ): + """Infer input x,y from output x,y with various coordinate transformation methods""" + in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) + in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode) + return in_y, in_x - """Perform resize operation with nearest neighbor method on the data. - For details about Nearest-neighbor interpolation please refer to - https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation. - - Parameters - ---------- - indices : tuple - The indices of input data - - data : tvm.te.Tensor - inputs is a 4-D tensor with shape - [batch, channel, in_height, in_width] - or [batch, in_height, in_width, channel] - - image_height : integer - Input image height - - image_width : integer - Input image width - - target_height : integer - The target resized image height - - target_width : integer - The target resized image width - - boxes : tvm.te.Tensor, optional - A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies - the coordinates of a box. - - box_indices : tvm.te.Tensor, optional - A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that - the i-th box refers to. - - extrapolation_value: float, optional - Value used for extrapolation, when applicable. - - layout: string, optional - "NCHW", "NHWC", or "NCHWc". - - coordinate_transformation_mode: string, optional - Describes how to transform the coordinate in the resized tensor - to the coordinate in the original tensor. - Refer to the ONNX Resize operator specification for details. - Available options are "half_pixel", "align_corners" and "asymmetric". - - rounding_method: string, optional - indicates how to find the "nearest" pixel in nearest_neighbor method - [round, floor, ceil] - - out_dtype: string, optional - Type to return. If left None will be same as input type. - - Returns - ------- - output : out_dtype - The computed result with type out_dtype - """ - if rounding_method == "": - if coordinate_transformation_mode == "align_corners": - rounding_method = "round" - else: - rounding_method = "floor" - - def _cast_output(value, data_dtype="float32", out_dtype=None): - if out_dtype: - dtype = out_dtype - else: - dtype = data_dtype - return value.astype(dtype) - - n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout) - box_idx = box_indices(n) if box_indices is not None else n - if boxes is not None: - y1, x1 = boxes(n, 0), boxes(n, 1) - y2, x2 = boxes(n, 2), boxes(n, 3) - - in_h = (image_height - 1) * (y2 - y1) - in_w = (image_width - 1) * (x2 - x1) - h_scale = in_h.astype("float") / (target_height - 1) - w_scale = in_w.astype("float") / (target_width - 1) - - in_y = y1 * (image_height - 1) + h_scale * y - in_x = x1 * (image_width - 1) + w_scale * x - else: - in_y, in_x = get_iny_inx( - y, - x, - image_height, - image_width, - target_height, - target_width, - coordinate_transformation_mode, - ) +def get_closest_index(in_x, rounding_method, boxes): if rounding_method == "round" or boxes is not None: closest_x_index = te.round(in_x).astype("int32") - closest_y_index = te.round(in_y).astype("int32") elif rounding_method == "round_prefer_floor": closest_x_index = te.ceil(in_x - 0.5).astype("int32") - closest_y_index = te.ceil(in_y - 0.5).astype("int32") elif rounding_method == "round_prefer_ceil": closest_x_index = te.floor(in_x + 0.5).astype("int32") - closest_y_index = te.floor(in_y + 0.5).astype("int32") elif rounding_method == "floor": # Add epsilon to floor to prevent gpu rounding errors. epsilon = 1e-5 - closest_y_index = te.floor(in_y + epsilon).astype("int32") closest_x_index = te.floor(in_x + epsilon).astype("int32") elif rounding_method == "ceil": # Subract epsilon from ceil to prevent gpu rounding errors. epsilon = 1e-5 - closest_y_index = te.ceil(in_y - epsilon).astype("int32") closest_x_index = te.ceil(in_x - epsilon).astype("int32") else: raise ValueError("Uknown rounding method: {}".format(rounding_method)) + return closest_x_index - value = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - closest_y_index, - closest_x_index, - cc, - inum, - ic, - ) - if extrapolation_value is not None: - out = tvm.tir.if_then_else( - in_y < 0, - extrapolation_value, - tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), - ) - # use extrapolation_value if in_x is out of boundary - value = tvm.tir.if_then_else( - in_x < 0, - extrapolation_value, - tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), - ) - return _cast_output(value, data.dtype, out_dtype=out_dtype) +def _lerp(A, B, t): + return A * (1.0 - t) + B * t + + +def _cubic_spline_weights(t, alpha): + """create cubic spline weights in 1D""" + t2 = t * t + t3 = t * t * t + w1 = alpha * (t3 - 2 * t2 + t) + w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 + w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t + w4 = -alpha * t3 + alpha * t2 + return [w1, w2, w3, w4] -def resize_bilinear( +def _cubic_kernel(inputs, w): + """perform cubic interpolation in 1D""" + return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) + + +def _resize_2d( indices, data, image_height, @@ -259,15 +172,17 @@ def resize_bilinear( target_width, boxes=None, box_indices=None, + method=None, extrapolation_value=None, layout="NCHW", coordinate_transformation_mode="align_corners", + rounding_method="", + alpha=-0.5, + exclude_outside=0, out_dtype=None, ): - """Perform resize operation with bilinear method on the data. - For details about Bilinear interpolation please refer to - https://en.wikipedia.org/wiki/Bilinear_interpolation. + """Perform resize operation on the data with selected method and options. Parameters ---------- @@ -311,6 +226,16 @@ def resize_bilinear( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + alpha: float, optional + Bicubic spline coefficient + + exclude_oiutside: bool, optional: + Exclude values outside the image fdor bicubic interpolation + out_dtype: string, optional Type to return. If left None will be same as input type. @@ -327,12 +252,8 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): dtype = data_dtype return value.astype(dtype) - def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout=layout) + n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout) box_idx = box_indices(n) if box_indices is not None else n - if boxes is not None: y1, x1 = boxes(n, 0), boxes(n, 1) y2, x2 = boxes(n, 2), boxes(n, 3) @@ -355,258 +276,115 @@ def _lerp(A, B, t): coordinate_transformation_mode, ) - top_y_index = te.floor(in_y).astype("int32") - bottom_y_index = te.ceil(in_y).astype("int32") - y_lerp = in_y - top_y_index - - left_x_index = te.floor(in_x).astype("int32") - right_x_index = te.ceil(in_x).astype("int32") - x_lerp = in_x - left_x_index - - top_left = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - top_y_index, - left_x_index, - cc, - inum, - ic, - ) - top_right = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - top_y_index, - right_x_index, - cc, - inum, - ic, - ) - bottom_left = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - bottom_y_index, - left_x_index, - cc, - inum, - ic, - ) - bottom_right = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - bottom_y_index, - right_x_index, - cc, - inum, - ic, - ) - - top = _lerp(top_left, top_right, x_lerp) - bottom = _lerp(bottom_left, bottom_right, x_lerp) - value = _lerp(top, bottom, y_lerp) - - # use extrapolation_value if in_y/in_x is out of boundary - if extrapolation_value is not None: - out = tvm.tir.if_then_else( - in_y < 0, - extrapolation_value, - tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), - ) - value = tvm.tir.if_then_else( - in_x < 0, - extrapolation_value, - tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), - ) - return _cast_output(value, data.dtype, out_dtype=out_dtype) - - -def resize_bicubic( - indices, - data, - image_height, - image_width, - target_height, - target_width, - boxes=None, - box_indices=None, - extrapolation_value=None, - layout="NCHW", - coordinate_transformation_mode="align_corners", - out_dtype=None, - alpha=-0.5, - exclude_outside=0, -): - """Perform resize operation with bicubic method on the data. - More details about Bicubic interpolation please refer to - https://en.wikipedia.org/wiki/Bicubic_interpolation. - This algorithm is doing a bicubic spline interpolation - - Parameters - ---------- - indices : tuple - The indices of input data - - data : tvm.te.Tensor - inputs is a 4-D tensor with shape - [batch, channel, in_height, in_width] - or [:batch, in_height, in_width, channel] - - image_height : integer - Input image height - - image_width : integer - Input image width - - target_height : integer - The target resized image height - - target_width : integer - The target resized image width - - boxes : tvm.te.Tensor, optional - A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies - the coordinates of a box. - - box_indices : tvm.te.Tensor, optional - A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that - the i-th box refers to. - - extrapolation_value: float, optional - Value used for extrapolation, when applicable. - - layout: string, optional - "NCHW", "NHWC", or "NCHWc". - - coordinate_transformation_mode: string, optional - Describes how to transform the coordinate in the resized tensor - to the coordinate in the original tensor. - Refer to the ONNX Resize operator specification for details. - Available options are "half_pixel", "align_corners" and "asymmetric". + if method == "nearest_neighbor": + if rounding_method == "": + if coordinate_transformation_mode == "align_corners": + rounding_method = "round" + else: + rounding_method = "floor" - out_dtype: string, optional - Type to return. If left None will be same as input type. + closest_x_index = get_closest_index(in_x, rounding_method, boxes) + closest_y_index = get_closest_index(in_y, rounding_method, boxes) - alpha: float, optional - Bicubic spline coefficient + value = get_2d_pixel( + data, + layout, + boxes, + image_height, + image_width, + box_idx, + c, + closest_y_index, + closest_x_index, + cc, + inum, + ic, + ) + elif method == "bilinear": + y_int = te.floor(in_y).astype("int32") + x_int = te.floor(in_x).astype("int32") + + y_lerp = in_y - y_int + x_lerp = in_x - x_int + + p = [[0 for i in range(2)] for j in range(2)] + for j in range(2): + for i in range(2): + p[j][i] = get_2d_pixel( + data, + layout, + boxes, + image_height, + image_width, + box_idx, + c, + y_int + j, + x_int + i, + cc, + inum, + ic, + ) - Returns - ------- - output : out_dtype - The computed result with type out_dtype - """ + top = _lerp(*p[0], x_lerp) + bottom = _lerp(*p[1], x_lerp) + value = _lerp(top, bottom, y_lerp) - def _cast_output(value, data_dtype="float32", out_dtype=None): - if out_dtype: - dtype = out_dtype - else: - dtype = data_dtype - return value.astype(dtype) + elif method == "bicubic": + xint = te.floor(in_x).astype("int32") + xfract = in_x - te.floor(in_x) - n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout) - box_idx = box_indices(n) if box_indices is not None else n + yint = te.floor(in_y).astype("int32") + yfract = in_y - te.floor(in_y) - if boxes is not None: - y1, x1 = boxes(n, 0), boxes(n, 1) - y2, x2 = boxes(n, 2), boxes(n, 3) + # Get the surrounding values + p = [[0 for i in range(4)] for j in range(4)] + for j in range(4): + for i in range(4): + p[j][i] = get_2d_pixel( + data, + layout, + boxes, + image_height, + image_width, + box_idx, + c, + yint + j - 1, + xint + i - 1, + cc, + inum, + ic, + ) - in_h = (image_height - 1) * (y2 - y1) - in_w = (image_width - 1) * (x2 - x1) - h_scale = in_h.astype("float") / (target_height - 1) - w_scale = in_w.astype("float") / (target_width - 1) + wx = _cubic_spline_weights(xfract, alpha) + wy = _cubic_spline_weights(yfract, alpha) + if exclude_outside: + for i in range(4): + wx[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + ) + wy[i] = te.if_then_else( + te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i] + ) + sum_wx = sum(wx) + sum_wy = sum(wy) + wx = [w / sum_wx for w in wx] + wy = [w / sum_wy for w in wy] + col0 = _cubic_kernel(p[0], wx) + col1 = _cubic_kernel(p[1], wx) + col2 = _cubic_kernel(p[2], wx) + col3 = _cubic_kernel(p[3], wx) + value = _cubic_kernel([col0, col1, col2, col3], wy) - in_y = y1 * (image_height - 1) + h_scale * y - in_x = x1 * (image_width - 1) + w_scale * x else: - in_y, in_x = get_iny_inx( - y, - x, - image_height, - image_width, - target_height, - target_width, - coordinate_transformation_mode, - ) + raise ValueError("Unknown resize method:", method) - xint = te.floor(in_x).astype("int32") - xfract = in_x - te.floor(in_x) - - yint = te.floor(in_y).astype("int32") - yfract = in_y - te.floor(in_y) - - # Get the surrounding values - p = [[0 for i in range(4)] for j in range(4)] - for j in range(4): - for i in range(4): - p[j][i] = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - yint + j - 1, - xint + i - 1, - cc, - inum, - ic, - ) - - # Interpolate bicubically - def _cubic_spline_weights(t): - t2 = t * t - t3 = t * t * t - w1 = alpha * (t3 - 2 * t2 + t) - w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 - w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t - w4 = -alpha * t3 + alpha * t2 - return [w1, w2, w3, w4] - - def _cubic_kernel(inputs, w): - return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) - - wx = _cubic_spline_weights(xfract) - wy = _cubic_spline_weights(yfract) - if exclude_outside: - for i in range(4): - wx[i] = te.if_then_else(te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i]) - wy[i] = te.if_then_else(te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i]) - sum_wx = sum(wx) - sum_wy = sum(wy) - wx = [w / sum_wx for w in wx] - wy = [w / sum_wy for w in wy] - col0 = _cubic_kernel(p[0], wx) - col1 = _cubic_kernel(p[1], wx) - col2 = _cubic_kernel(p[2], wx) - col3 = _cubic_kernel(p[3], wx) - value = _cubic_kernel([col0, col1, col2, col3], wy) - - # use extrapolation_value if in_y/in_x is out of boundary if extrapolation_value is not None: out = tvm.tir.if_then_else( in_y < 0, extrapolation_value, tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), ) + # use extrapolation_value if in_x is out of boundary value = tvm.tir.if_then_else( in_x < 0, extrapolation_value, @@ -692,58 +470,23 @@ def resize2d( if isinstance(size[i], int): size[i] = tvm.tir.IntImm("int32", size[i]) - def _nearest_neighbor(*indices): - return resize_nearest_neighbor( + def compute_func(*indices): + return _resize_2d( indices, data, in_h, in_w, size[0], size[1], + method=method, layout=layout, coordinate_transformation_mode=coordinate_transformation_mode, rounding_method=rounding_method, - out_dtype=out_dtype, - ) - - def _bilinear(*indices): - return resize_bilinear( - indices, - data, - in_h, - in_w, - size[0], - size[1], - layout=layout, - coordinate_transformation_mode=coordinate_transformation_mode, - out_dtype=out_dtype, - ) - - def _bicubic(*indices): - return resize_bicubic( - indices, - data, - in_h, - in_w, - size[0], - size[1], - layout=layout, - coordinate_transformation_mode=coordinate_transformation_mode, - out_dtype=out_dtype, alpha=bicubic_alpha, exclude_outside=bicubic_exclude, + out_dtype=out_dtype, ) - # Determine which interpolation method to use then run it. - if method == "nearest_neighbor": - compute_func = _nearest_neighbor - elif method == "bilinear": - compute_func = _bilinear - elif method == "bicubic": - compute_func = _bicubic - else: - raise ValueError("%s method is not supported." % method) - return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE) @@ -819,23 +562,8 @@ def crop_and_resize( else: raise ValueError("%s layout is not supported." % layout) - def _bilinear(*indices): - return resize_bilinear( - indices, - data, - image_h, - image_w, - target_h, - target_w, - boxes, - box_indices, - extrapolation_value, - layout, - out_dtype=out_dtype, - ) - - def _nearest_neighbor(*indices): - return resize_nearest_neighbor( + def compute_func(*indices): + return _resize_2d( indices, data, image_h, @@ -844,19 +572,12 @@ def _nearest_neighbor(*indices): target_w, boxes, box_indices, - extrapolation_value, - layout, + method=method, + extrapolation_value=extrapolation_value, + layout=layout, out_dtype=out_dtype, ) - # Determine which interpolation method to use then run it. - if method == "nearest_neighbor": - compute_func = _nearest_neighbor - elif method == "bilinear": - compute_func = _bilinear - else: - raise ValueError("%s method is not supported." % method) - return te.compute(output_shape, compute_func, name="crop_and_resize", tag=tag.INJECTIVE) diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 3a056cfb4326..be3df2be5f6a 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -31,6 +31,16 @@ class InvalidShapeError(ValueError): """Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)""" +def ncw_pack_layout(layout_info): + """Check whether the layout type is NCWinic""" + return layout_info[:3] == "NCW" and "c" in layout_info and "n" in layout_info + + +def ncw_xc_layout(layout_info): + """Check whether the layout type is NCWxc""" + return layout_info[:3] == "NCW" and "c" in layout_info and layout_info[3:-1].isnumeric() + + def nchw_pack_layout(layout_info): """Check whether the layout type is NCHWinic""" return layout_info[:4] == "NCHW" and "c" in layout_info and "n" in layout_info From 638ab02d9926ac6047f9f3af97f61f3a40bf66d2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 25 Jun 2021 14:48:02 -0600 Subject: [PATCH 03/12] Add resize1d op, normalize attribute names across ops --- include/tvm/relay/attrs/image.h | 99 +++++-- python/tvm/relay/frontend/onnx.py | 45 +++- python/tvm/relay/op/dyn/image/_image.py | 8 +- python/tvm/relay/op/image/_image.py | 17 +- python/tvm/relay/op/image/image.py | 38 +-- python/tvm/topi/image/resize.py | 287 ++++++++++++++++++++- src/relay/op/dyn/image/resize.cc | 8 +- src/relay/op/image/resize.cc | 85 +++++- src/relay/op/make_op.h | 4 +- src/relay/transforms/dynamic_to_static.cc | 2 +- tests/python/frontend/onnx/test_forward.py | 87 +++++-- 11 files changed, 572 insertions(+), 108 deletions(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index 28ef3fb03aa3..b851add61e4a 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -32,15 +32,60 @@ namespace tvm { namespace relay { -/*! \brief Attributes used in image resize operator */ +/*! \brief Attributes used in image resize1d operator */ +struct Resize1DAttrs : public tvm::AttrsNode { + Array size; + std::string layout; + std::string method; + std::string coordinate_transformation_mode; + std::string rounding_method; + double cubic_alpha; + int cubic_exclude; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Resize1DAttrs, "relay.attrs.Resize1DAttrs") { + TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); + TVM_ATTR_FIELD(layout).set_default("NCW").describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel and width" + "dimensions respectively. Resize is applied on the" + "'W' dimension."); + TVM_ATTR_FIELD(method).set_default("linear").describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Linear Interpolation" + "cubic - Cubic Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .set_default("half_pixel") + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .set_default("round") + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(cubic_alpha) + .set_default(-0.5) + .describe("Spline Coefficient for cubic interpolation"); + TVM_ATTR_FIELD(cubic_exclude) + .set_default(0) + .describe("Flag to exclude exterior of the image during cubic interpolation"); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); + } +}; + +/*! \brief Attributes used in image resize2d operator */ struct Resize2DAttrs : public tvm::AttrsNode { Array size; std::string layout; std::string method; std::string coordinate_transformation_mode; std::string rounding_method; - double bicubic_alpha; - int bicubic_exclude; + double cubic_alpha; + int cubic_exclude; DataType out_dtype; TVM_DECLARE_ATTRS(Resize2DAttrs, "relay.attrs.Resize2DAttrs") { @@ -50,13 +95,11 @@ struct Resize2DAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Resize is applied on the 'H' and" "'W' dimensions."); - TVM_ATTR_FIELD(method) - .set_default("bilinear") - .describe( - "Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation" - "bicubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(method).set_default("linear").describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Bilinear Interpolation" + "cubic - Bicubic Interpolation"); TVM_ATTR_FIELD(coordinate_transformation_mode) .set_default("half_pixel") .describe( @@ -69,10 +112,10 @@ struct Resize2DAttrs : public tvm::AttrsNode { .describe( "indicates how to find the \"nearest\" pixel in nearest_neighbor method" "Available options are round, floor, and ceil."); - TVM_ATTR_FIELD(bicubic_alpha) + TVM_ATTR_FIELD(cubic_alpha) .set_default(-0.5) .describe("Spline Coefficient for Bicubic Interpolation"); - TVM_ATTR_FIELD(bicubic_exclude) + TVM_ATTR_FIELD(cubic_exclude) .set_default(0) .describe("Flag to exclude exterior of the image during bicubic interpolation"); TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); @@ -82,9 +125,12 @@ struct Resize2DAttrs : public tvm::AttrsNode { /*! \brief Attributes used in image resize3d operator */ struct Resize3DAttrs : public tvm::AttrsNode { Array size; - String layout; - String method; - String coordinate_transformation_mode; + std::string layout; + std::string method; + std::string coordinate_transformation_mode; + std::string rounding_method; + double cubic_alpha; + int cubic_exclude; DataType out_dtype; TVM_DECLARE_ATTRS(Resize3DAttrs, "relay.attrs.Resize3DAttrs") { @@ -94,18 +140,29 @@ struct Resize3DAttrs : public tvm::AttrsNode { "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" "dimensions respectively. Resize3d is applied on the 'D', 'H' and" "'W' dimensions."); - TVM_ATTR_FIELD(method) - .set_default("trilinear") - .describe( - "Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "trilinear - Trilinear Interpolation"); + TVM_ATTR_FIELD(method).set_default("linear").describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Trilinear Interpolation" + "cubic - Tricubic Interpolation"); TVM_ATTR_FIELD(coordinate_transformation_mode) .set_default("half_pixel") .describe( "Describes how to transform the coordinate in the resized tensor" "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .set_default("round") + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(cubic_alpha) + .set_default(-0.5) + .describe("Spline Coefficient for Tricubic Interpolation"); + TVM_ATTR_FIELD(cubic_exclude) + .set_default(0) + .describe("Flag to exclude exterior of the image during tricubic interpolation"); TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index fd88cddb94ac..73303b28a9b2 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2395,9 +2395,9 @@ def _impl_v10(cls, inputs, attr, params): if mode == "nearest": method = "nearest_neighbor" elif mode == "linear": - method = "bilinear" + method = "linear" elif mode == "cubic": - method = "bicubic" + method = "cubic" else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) @@ -2406,23 +2406,28 @@ def _impl_v10(cls, inputs, attr, params): scale = inputs[1] size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale ndims = len(infer_shape(inputs[0])) - layout = {3: "NCW", 4: "NCHW", 5: "NCDHW"}[ndims] - out_size = fold_constant(_op.strided_slice(size, [2], [4])) - return _op.image.resize2d(inputs[0], out_size, layout, method, "asymmetric") + if ndims == 3: + out_size = fold_constant(_op.strided_slice(size, [2], [3])) + return _op.image.resize1d(inputs[0], out_size, "NCW", method, "asymmetric") + elif ndims == 4: + out_size = fold_constant(_op.strided_slice(size, [2], [4])) + return _op.image.resize2d(inputs[0], out_size, "NCHW", method, "asymmetric") + elif ndims == 5: + out_size = fold_constant(_op.strided_slice(size, [2], [5])) + return _op.image.resize3d(inputs[0], out_size, "NCDHW", method, "asymmetric") + else: + raise NotImplementedError("Resize only supports 3, 4, or 5 dims") @classmethod def _impl_v11(cls, inputs, attr, params): - print({**attr}) ndims = len(infer_shape(inputs[0])) - layout = {3: "NCH", 4: "NCHW", 5: "NCDHW"}[ndims] - mode = attr.get("mode").decode("ascii") if mode == "nearest": method = "nearest_neighbor" elif mode == "linear": - method = "bilinear" + method = "linear" elif mode == "cubic": - method = "bicubic" + method = "cubic" else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) @@ -2445,9 +2450,23 @@ def _impl_v11(cls, inputs, attr, params): size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale out_size = fold_constant(_op.strided_slice(size, [2], [4])) - return _op.image.resize2d( - inputs[0], out_size, layout, method, coord_trans, nearest_mode, alpha, exclude - ) + if ndims == 3: + out_size = fold_constant(_op.strided_slice(size, [2], [3])) + return _op.image.resize1d( + inputs[0], out_size, "NCW", method, coord_trans, nearest_mode, alpha, exclude + ) + elif ndims == 4: + out_size = fold_constant(_op.strided_slice(size, [2], [4])) + return _op.image.resize2d( + inputs[0], out_size, "NCHW", method, coord_trans, nearest_mode, alpha, exclude + ) + elif ndims == 5: + out_size = fold_constant(_op.strided_slice(size, [2], [5])) + return _op.image.resize3d( + inputs[0], out_size, "NCDHW", method, coord_trans, nearest_mode, alpha, exclude + ) + else: + raise NotImplementedError("Resize only supports 3, 4, or 5 dims") class NonZero(OnnxOpConverter): diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index 7f66a69b4803..5e97d2461100 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -32,8 +32,8 @@ def compute_resize2d(attrs, inputs, out_type): method = attrs.method coord_trans = attrs.coordinate_transformation_mode rounding_method = attrs.rounding_method - bicubic_alpha = attrs.bicubic_alpha - bicubic_exclude = attrs.bicubic_exclude + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude out_dtype = attrs.out_dtype return [ tvm.topi.image.resize2d( @@ -43,8 +43,8 @@ def compute_resize2d(attrs, inputs, out_type): method, coord_trans, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, out_type.shape, ) diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index e7c7996b7a1c..aa6c95693dfd 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -38,8 +38,8 @@ def compute_resize1d(attrs, inputs, out_type): method = attrs.method coord_trans = attrs.coordinate_transformation_mode rounding_method = attrs.rounding_method - bicubic_alpha = attrs.bicubic_alpha - bicubic_exclude = attrs.bicubic_exclude + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude out_dtype = attrs.out_dtype return [ topi.image.resize1d( @@ -49,8 +49,8 @@ def compute_resize1d(attrs, inputs, out_type): method, coord_trans, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, ) ] @@ -132,8 +132,8 @@ def compute_resize2d(attrs, inputs, out_type): method = attrs.method coord_trans = attrs.coordinate_transformation_mode rounding_method = attrs.rounding_method - bicubic_alpha = attrs.bicubic_alpha - bicubic_exclude = attrs.bicubic_exclude + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude out_dtype = attrs.out_dtype return [ topi.image.resize2d( @@ -143,8 +143,8 @@ def compute_resize2d(attrs, inputs, out_type): method, coord_trans, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, ) ] @@ -168,7 +168,6 @@ def convert_image_resize2d(attrs, inputs, tinfos, desired_layouts): desired_layouts : list of layout strings List of layouts defining our desired layout for the data input. - Returns ------- result : tvm.relay.Expr diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index ab6b0df9850b..2f788bb0c23d 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -27,8 +27,8 @@ def resize1d( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", - bicubic_alpha=-0.5, - bicubic_exclude=0, + cubic_alpha=-0.5, + cubic_exclude=0, out_dtype=None, ): """Image resize1d operator. @@ -65,11 +65,11 @@ def resize1d( indicates how to find the "nearest" pixel in nearest_neighbor method [round, floor, ceil] - bicubic_alpha: float - Spline Coefficient for Bicubic Interpolation + cubic_alpha: float + Spline Coefficient for cubic interpolation - bicubic_exclude: int - Flag to exclude exterior of the image during bicubic interpolation + cubic_exclude: int + Flag to exclude exterior of the image during cubic interpolation out_dtype : str, optional Type to return. If left None returns the same type as input. @@ -90,8 +90,8 @@ def resize1d( method, coordinate_transformation_mode, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, ) @@ -103,8 +103,8 @@ def resize2d( method="bilinear", coordinate_transformation_mode="half_pixel", rounding_method="", - bicubic_alpha=-0.5, - bicubic_exclude=0, + cubic_alpha=-0.5, + cubic_exclude=0, out_dtype=None, ): """Image resize2d operator. @@ -115,7 +115,7 @@ def resize2d( out will have a shape (n, c, size[0], size[1]) method indicates the algorithm to be used while calculating the out value - and method can be one of ("bilinear", "nearest_neighbor", "bicubic") + and method can be one of ("linear", "nearest_neighbor", "cubic") Parameters ---------- @@ -129,7 +129,7 @@ def resize2d( Layout of the input. method : str, optional - Scale method to used [nearest_neighbor, bilinear, bicubic]. + Scale method to used [nearest_neighbor, linear, cubic]. coordinate_transformation_mode : string, optional Describes how to transform the coordinate in the resized tensor @@ -141,10 +141,10 @@ def resize2d( indicates how to find the "nearest" pixel in nearest_neighbor method [round, floor, ceil] - bicubic_alpha: float - Spline Coefficient for Bicubic Interpolation + cubic_alpha: float + Spline Coefficient for bicubic interpolation - bicubic_exclude: int + cubic_exclude: int Flag to exclude exterior of the image during bicubic interpolation out_dtype : str, optional @@ -165,8 +165,8 @@ def resize2d( method, coordinate_transformation_mode, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, ) return _make.resize2d( @@ -176,8 +176,8 @@ def resize2d( method, coordinate_transformation_mode, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, ) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 87e3ff2ee54f..c7159df63d06 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -163,6 +163,281 @@ def _cubic_kernel(inputs, w): return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) +def _resize_1d( + indices, + data, + image_width, + target_width, + boxes=None, + box_indices=None, + method=None, + extrapolation_value=None, + layout="NCW", + coordinate_transformation_mode="align_corners", + rounding_method="", + alpha=-0.5, + exclude_outside=0, + out_dtype=None, +): + + """Perform resize operation on the data with selected method and options. + + Parameters + ---------- + indices : tuple + The indices of input data + + data : tvm.te.Tensor + inputs is a 3-D tensor with shape + [batch, channel, in_width] + or [batch, in_width, channel] + + image_width : integer + Input image width + + target_width : integer + The target resized image width + + boxes : tvm.te.Tensor, optional + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : tvm.te.Tensor, optional + A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that + the i-th box refers to. + + extrapolation_value: float, optional + Value used for extrapolation, when applicable. + + layout: string, optional + "NCW", "NWC", or "NCWc". + + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". + + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + alpha: float, optional + Bicubic spline coefficient + + exclude_oiutside: bool, optional: + Exclude values outside the image fdor bicubic interpolation + + out_dtype: string, optional + Type to return. If left None will be same as input type. + + Returns + ------- + output : out_dtype + The computed result with type out_dtype + """ + + def _cast_output(value, data_dtype="float32", out_dtype=None): + if out_dtype: + dtype = out_dtype + else: + dtype = data_dtype + return value.astype(dtype) + + n, c, x, cc, inum, ic = get_1d_indices(indices, layout) + box_idx = box_indices(n) if box_indices is not None else n + if boxes is not None: + # TODO(mbrookhart): Find an example of this + raise NotImplementedError("resize1d with image boxes not yet implemented") + else: + in_x = get_inx( + x, + image_width, + target_width, + coordinate_transformation_mode, + ) + + if method == "nearest_neighbor": + if rounding_method == "": + if coordinate_transformation_mode == "align_corners": + rounding_method = "round" + else: + rounding_method = "floor" + + closest_x_index = get_closest_index(in_x, rounding_method, boxes) + + value = get_1d_pixel( + data, + layout, + boxes, + image_width, + box_idx, + c, + closest_x_index, + cc, + inum, + ic, + ) + elif method == "linear": + x_int = te.floor(in_x).astype("int32") + + x_lerp = in_x - x_int + + p = [0 for i in range(2)] + for i in range(2): + p[i] = get_1d_pixel( + data, + layout, + boxes, + image_width, + box_idx, + c, + x_int + i, + cc, + inum, + ic, + ) + + value = _lerp(*p, x_lerp) + + elif method == "cubic": + xint = te.floor(in_x).astype("int32") + xfract = in_x - te.floor(in_x) + + # Get the surrounding values + p = [0 for i in range(4)] + for i in range(4): + p[i] = get_1d_pixel( + data, + layout, + boxes, + image_width, + box_idx, + c, + xint + i - 1, + cc, + inum, + ic, + ) + + wx = _cubic_spline_weights(xfract, alpha) + if exclude_outside: + for i in range(4): + wx[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + ) + sum_wx = sum(wx) + wx = [w / sum_wx for w in wx] + value = _cubic_kernel(p, wx) + + else: + raise ValueError("Unknown resize method:", method) + + if extrapolation_value is not None: + # use extrapolation_value if in_x is out of boundary + value = tvm.tir.if_then_else( + in_x < 0, + extrapolation_value, + tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, value), + ) + return _cast_output(value, data.dtype, out_dtype=out_dtype) + + +def resize1d( + data, + size, + layout="NCW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + bicubic_alpha=-0.5, + bicubic_exclude=0, + out_dtype=None, + output_shape=None, +): + """Perform resize operation on the data. + + Parameters + ---------- + data : tvm.te.Tensor + inputs is a 3-D tensor with shape + [batch, channel in_width] + or [batch in_width, channel] + + size: Tuple + Output resolution scale to + + layout: string, optional + "NCW", "NWC", or "NCWc". + + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". + + method: {"linear", "nearest_neighbor", "cubic"} + Method to be used for resizing. + + out_dtype: string, optional + Type to return. If left None will be same as input type. + + output_shape: tvm.tir.container.Array, optional + Shape to return. If left None will be inferred + (If shape is determined dynamically, pass out_dtype.shape as output_shape) + + Returns + ------- + output : tvm.te.Tensor + 4-D with shape [batch, chananel, in_width*scale] + or [batch, in_width*scale, channel] + or 5-D with shape [batch, channel-major, in_width*scale, channel-minor] + """ + method = method.lower() + if layout == "NWC": + in_n, in_w, in_c = data.shape + if output_shape is None: + output_shape = [in_n, size[0], in_c] + elif layout == "NCW": + in_n, in_c, in_w = data.shape + if output_shape is None: + output_shape = [in_n, in_c, size[0]] + elif ncw_pack_layout(layout): # for NCWinic + in_n, in_c, in_w, in_inum, in_ic = data.shape + if output_shape is None: + output_shape = [in_n, in_c, size[0], in_inum, in_ic] + elif ncw_xc_layout(layout): # for NCWxc + in_n, in_c, in_w, in_cc = data.shape + if output_shape is None: + output_shape = [in_n, in_c, size[0], in_cc] + else: + raise ValueError("%s layout is not supported." % layout) + + if isinstance(size, tuple): + size = list(size) + + for i in range(1): + if isinstance(size[i], int): + size[i] = tvm.tir.IntImm("int32", size[i]) + + def compute_func(*indices): + return _resize_1d( + indices, + data, + in_w, + size[0], + method=method, + layout=layout, + coordinate_transformation_mode=coordinate_transformation_mode, + rounding_method=rounding_method, + alpha=bicubic_alpha, + exclude_outside=bicubic_exclude, + out_dtype=out_dtype, + ) + + return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE) + + def _resize_2d( indices, data, @@ -300,7 +575,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): inum, ic, ) - elif method == "bilinear": + elif method == "linear": y_int = te.floor(in_y).astype("int32") x_int = te.floor(in_x).astype("int32") @@ -329,7 +604,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): bottom = _lerp(*p[1], x_lerp) value = _lerp(top, bottom, y_lerp) - elif method == "bicubic": + elif method == "cubic": xint = te.floor(in_x).astype("int32") xfract = in_x - te.floor(in_x) @@ -397,7 +672,7 @@ def resize2d( data, size, layout="NCHW", - method="bilinear", + method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", bicubic_alpha=-0.5, @@ -426,7 +701,7 @@ def resize2d( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". - method: {"bilinear", "nearest_neighbor", "bicubic"} + method: {"linear", "nearest_neighbor", "cubic"} Method to be used for resizing. out_dtype: string, optional @@ -561,6 +836,8 @@ def crop_and_resize( image_w = data.shape[3].astype("int32") else: raise ValueError("%s layout is not supported." % layout) + if method == "bilinear": + method = "linear" def compute_func(*indices): return _resize_2d( @@ -610,7 +887,7 @@ def resize3d( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". - method: {"trilinear", "nearest_neighbor"} + method: {"linear", "nearest_neighbor"} Method to be used for resizing. out_dtype: string, optional diff --git a/src/relay/op/dyn/image/resize.cc b/src/relay/op/dyn/image/resize.cc index c438b1a04790..002105f4d565 100644 --- a/src/relay/op/dyn/image/resize.cc +++ b/src/relay/op/dyn/image/resize.cc @@ -67,15 +67,15 @@ bool Resize2DRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. Expr MakeResize2D(Expr data, Expr size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, - double bicubic_alpha, double bicubic_exclude, DataType out_dtype) { + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + double cubic_exclude, DataType out_dtype) { auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; attrs->rounding_method = rounding_method; - attrs->bicubic_alpha = bicubic_alpha; - attrs->bicubic_exclude = bicubic_exclude; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("dyn.image.resize2d"); return Call(op, {data, size}, Attrs(attrs), {}); diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index e7626562b6a6..69c77ef0cfec 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -31,8 +31,6 @@ namespace tvm { namespace relay { -TVM_REGISTER_NODE_TYPE(Resize2DAttrs); - template InferCorrectLayoutOutput ResizeInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, @@ -58,6 +56,81 @@ InferCorrectLayoutOutput ResizeInferCorrectLayout(const Attrs& attrs, return InferCorrectLayoutOutput({params->layout}, {params->layout}, Attrs(params)); } +TVM_REGISTER_NODE_TYPE(Resize1DAttrs); + +bool Resize1DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + static const Layout kNCW("NCW"); + + const Resize1DAttrs* param = attrs.as(); + ICHECK(param != nullptr); + const Layout in_layout(param->layout); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCW); + ICHECK(layout_converter.defined()) + << "Resize only support input layouts that are convertible from NCW." + << " But got " << in_layout; + + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, param->size[0]); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + + // assign output type + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), out_dtype)); + return true; +} + +// Positional relay function to create image operator +// used by frontend FFI. +Expr MakeResize1D(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype) { + auto attrs = make_object(); + attrs->size = std::move(size); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->rounding_method = rounding_method; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; + attrs->out_dtype = out_dtype; + static const Op& op = Op::Get("image.resize1d"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.image._make.resize1d").set_body_typed(MakeResize1D); + +RELAY_REGISTER_OP("image.resize1d") + .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. + +- **data**: data is 3D array of shape + (batch_size, channels, in_width) for NCW + (batch_size, in_width, channels) for NWC + +- **out**: Output is 3D array of shape + for layout NCW + (batch_size, channels, size[0]) + + for layout NWC + (batch_size, size[0], channels) +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("Resize1D", Resize1DRel) + .set_attr("FInferCorrectLayout", ResizeInferCorrectLayout) + .set_attr("TOpPattern", kInjective); + +TVM_REGISTER_NODE_TYPE(Resize2DAttrs); + bool Resize2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); @@ -91,16 +164,16 @@ bool Resize2DRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. Expr MakeResize2D(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, - double bicubic_alpha, int bicubic_exclude, DataType out_dtype) { + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype) { auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; attrs->rounding_method = rounding_method; - attrs->bicubic_alpha = bicubic_alpha; - attrs->bicubic_exclude = bicubic_exclude; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize2d"); return Call(op, {data}, Attrs(attrs), {}); diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 927fb1bdbbe9..1a47193bb91a 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -102,8 +102,8 @@ Expr MakeZeros(Array shape, DataType dtype); Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype); Expr MakeResize2D(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, - double bicubic_alpha, int bicubic_exclude, DataType out_dtype); + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype); Expr MakeSparseToDense(Expr indices, Array output_shape, Expr values, Expr default_value); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index de3c61f2d573..318022fb86f5 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -119,7 +119,7 @@ class DynamicToStaticMutator : public MixedModeMutator { } return MakeResize2D(call_node->args[0], size_prim, param->layout, param->method, param->coordinate_transformation_mode, param->rounding_method, - param->bicubic_alpha, param->bicubic_exclude, param->out_dtype); + param->cubic_alpha, param->cubic_exclude, param->out_dtype); } return Expr(nullptr); }}, diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 52c3346e5807..7a76b14b6fce 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3548,7 +3548,7 @@ def test_gru(): @tvm.testing.uses_gpu def test_resize(): - def verify(ishape, oshape, scales, mode, coord_trans): + def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, exclude=False): nodes = [ make_constant_node("roi", onnx.TensorProto.FLOAT, (0,), []), make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales), @@ -3566,6 +3566,8 @@ def verify(ishape, oshape, scales, mode, coord_trans): outputs=["Y"], mode=mode, coordinate_transformation_mode=coord_trans, + cubic_coeff_a=alpha, + exclude_outside=exclude, ) ) @@ -3582,29 +3584,66 @@ def verify(ishape, oshape, scales, mode, coord_trans): verify_with_ort(model, [ishape], [oshape], use_vm=True, opset=11, freeze_params=True) - # upsampling - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "align_corners") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "align_corners") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "half_pixel") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "half_pixel") - - # downsampling - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "align_corners") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "align_corners") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "half_pixel") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "half_pixel") - - # scales are specified instead of sizes - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "asymmetric") - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "linear", "asymmetric") - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "align_corners") - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "linear", "align_corners") - verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel") - verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "nearest", "half_pixel") + for ndim in [1, 2]: + method = "nearest" + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + # upsampling + verify([1, 16] + [32] * ndim, [1, 16] + [64] * ndim, [], method, coord_trans) + # downsampling + verify([1, 16] + [32] * ndim, [1, 16] + [16] * ndim, [], method, coord_trans) + # scales are specified instead of sizes + verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method, coord_trans) + verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method, coord_trans) + + method = "linear" + # upsampling + verify([1, 16] + [32] * ndim, [1, 16] + [64] * ndim, [], method) + # downsampling + verify([1, 16] + [32] * ndim, [1, 16] + [16] * ndim, [], method) + # scales are specified instead of sizes + verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method) + verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method) + + if ndim == 2: + # ONNX Runtime only supports cubic interpolation for 2D images + method = "cubic" + for alpha in [0.5, 0.75]: + for exclude in [True, False]: + # upsampling + verify( + [1, 16] + [32] * ndim, + [1, 16] + [64] * ndim, + [], + method, + alpha=alpha, + exclude=exclude, + ) + # downsampling + verify( + [1, 16] + [32] * ndim, + [1, 16] + [16] * ndim, + [], + method, + alpha=alpha, + exclude=exclude, + ) + # scales are specified instead of sizes + verify( + [1, 16] + [32] * ndim, + [], + [1, 1] + [0.5] * ndim, + method, + alpha=alpha, + exclude=exclude, + ) + verify( + [1, 16] + [32] * ndim, + [], + [1, 1] + [2] * ndim, + method, + alpha=alpha, + exclude=exclude, + ) def verify_opset_10(ishape, scales, mode): nodes = [ From a592fd1e4dc99ce2505f1ff2992937fd72bf3594 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 25 Jun 2021 16:11:44 -0600 Subject: [PATCH 04/12] normalize resize3d to match the API of 1D and 2D --- python/tvm/relay/op/image/_image.py | 18 +- python/tvm/relay/op/image/image.py | 76 ++++ python/tvm/topi/image/resize.py | 442 +++++++++++++++------ src/relay/op/image/resize.cc | 6 +- tests/python/frontend/onnx/test_forward.py | 2 +- 5 files changed, 413 insertions(+), 131 deletions(-) diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index aa6c95693dfd..ec24ff76b90e 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -223,12 +223,28 @@ def resize2d_shape_func(attrs, inputs, _): @reg.register_compute("image.resize3d") def compute_resize3d(attrs, inputs, out_type): + """compute definition for resize3d op""" size = attrs.size layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode + rounding_method = attrs.rounding_method + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude out_dtype = attrs.out_dtype - return [topi.image.resize3d(inputs[0], size, layout, method, coord_trans, out_dtype)] + return [ + topi.image.resize3d( + inputs[0], + size, + layout, + method, + coord_trans, + rounding_method, + cubic_alpha, + cubic_exclude, + out_dtype, + ) + ] reg.register_injective_schedule("image.resize3d") diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 2f788bb0c23d..74edff140d6f 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -230,6 +230,82 @@ def resize3d( return _make.resize3d(data, size, layout, method, coordinate_transformation_mode, out_dtype) +def resize3d( + data, + size, + layout="NCDHW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + cubic_alpha=-0.5, + cubic_exclude=0, + out_dtype=None, +): + """Image resize3d operator. + + This operator takes data as input and does 3D scaling to the given scale factor. + In the default case, where the data_layout is `NCDHW` + with data of shape `(n, c, d, h, w)` + out will have a shape `(n, c, size[0], size[1], size[2])` + + method indicates the algorithm to be used while calculating the out value + and method can be one of ("linear", "nearest_neighbor", "cubic") + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + size: Tuple of Int or Expr + The out size to which the image will be resized. + + layout : str, optional + Layout of the input. + + method : str, optional + Scale method to used [nearest_neighbor, linear, cubic]. + + coordinate_transformation_mode : string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + [half_pixel, align_corners, asymmetric] + + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + cubic_alpha: float + Spline Coefficient for cubic interpolation + + cubic_exclude: int + Flag to exclude exterior of the image during cubic interpolation + + out_dtype : str, optional + Type to return. If left None returns the same type as input. + + Returns + ------- + result: relay.Expr + The resized result. + """ + if isinstance(size, Constant): + size = list(size.data.numpy().astype("int32")) + if isinstance(size, Expr): + raise NotImplementedError("dyn.resize3d is not yet implemented, got size", size) + return _make.resize3d( + data, + size, + layout, + method, + coordinate_transformation_mode, + rounding_method, + cubic_alpha, + cubic_exclude, + out_dtype, + ) + + def crop_and_resize( data, boxes, diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index c7159df63d06..ae7d7bd02fbd 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -61,6 +61,19 @@ def get_2d_indices(indices, layout="NCHW"): return n, c, y, x, cc, inum, ic +def get_3d_indices(indices, layout="NCDHW"): + if layout == "NDHWC": + n, z, y, x, c = indices + cc = None + elif layout == "NCDHW": + n, c, z, y, x = indices + cc = None + else: + n, c, z, y, x, cc = indices + + return n, c, z, y, x, cc + + def get_1d_pixel(data, layout, boxes, image_width, n, c, x, cc, ib, ic): """Get 1d pixel""" if boxes is None: @@ -94,6 +107,19 @@ def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, return data(n, c, y, x, cc).astype("float") +def get_3d_pixel(data, layout, image_depth, image_height, image_width, n, c, z, y, x, cc): + """Get 3d pixel""" + z = tvm.te.max(tvm.te.min(z, image_depth - 1), 0) + y = tvm.te.max(tvm.te.min(y, image_height - 1), 0) + x = tvm.te.max(tvm.te.min(x, image_width - 1), 0) + if layout == "NDHWC": + return data(n, z, y, x, c).astype("float") + if layout == "NCDHW": + return data(n, c, z, y, x).astype("float") + # else must be NCDHWxc + return data(n, c, z, y, x, cc).astype("float") + + def get_inx(x, image_width, target_width, coordinate_transformation_mode): """Infer input x from output x with various coordinate transformation methods""" scale_x = te.div(image_width.astype("float"), target_width.astype("float")) @@ -858,19 +884,272 @@ def compute_func(*indices): return te.compute(output_shape, compute_func, name="crop_and_resize", tag=tag.INJECTIVE) +def _resize_3d( + indices, + data, + image_depth, + image_height, + image_width, + target_depth, + target_height, + target_width, + boxes=None, + box_indices=None, + method=None, + extrapolation_value=None, + layout="NCHW", + coordinate_transformation_mode="align_corners", + rounding_method="", + alpha=-0.5, + exclude_outside=0, + out_dtype=None, +): + + """Perform resize operation on the data with selected method and options. + + Parameters + ---------- + indices : tuple + The indices of input data + + data : tvm.te.Tensor + inputs is a 4-D tensor with shape + [batch, channel, in_height, in_width] + or [batch, in_height, in_width, channel] + + image_depth : integer + Input image depth + + image_height : integer + Input image height + + image_width : integer + Input image width + + target_depth : integer + The target resized image depth + + target_height : integer + The target resized image height + + target_width : integer + The target resized image width + + boxes : tvm.te.Tensor, optional + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : tvm.te.Tensor, optional + A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that + the i-th box refers to. + + extrapolation_value: float, optional + Value used for extrapolation, when applicable. + + layout: string, optional + "NCHW", "NHWC", or "NCHWc". + + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". + + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + alpha: float, optional + Bicubic spline coefficient + + exclude_oiutside: bool, optional: + Exclude values outside the image fdor bicubic interpolation + + out_dtype: string, optional + Type to return. If left None will be same as input type. + + Returns + ------- + output : out_dtype + The computed result with type out_dtype + """ + + def _cast_output(value, data_dtype="float32", out_dtype=None): + if out_dtype: + dtype = out_dtype + else: + dtype = data_dtype + return value.astype(dtype) + + n, c, z, y, x, cc = get_3d_indices(indices, layout) + box_idx = box_indices(n) if box_indices is not None else n + if boxes is not None: + # TODO(mbrookhart): Find an example of this + raise NotImplementedError("resize1d with image boxes not yet implemented") + else: + in_z = get_inx(z, image_depth, target_depth, coordinate_transformation_mode) + in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode) + in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) + + if method == "nearest_neighbor": + if rounding_method == "": + if coordinate_transformation_mode == "align_corners": + rounding_method = "round" + else: + rounding_method = "floor" + + closest_z_index = get_closest_index(in_z, rounding_method, boxes) + closest_y_index = get_closest_index(in_y, rounding_method, boxes) + closest_x_index = get_closest_index(in_x, rounding_method, boxes) + + value = get_3d_pixel( + data, + layout, + image_depth, + image_height, + image_width, + box_idx, + c, + closest_z_index, + closest_y_index, + closest_x_index, + cc, + ) + elif method == "linear": + z_int = te.floor(in_z).astype("int32") + y_int = te.floor(in_y).astype("int32") + x_int = te.floor(in_x).astype("int32") + + z_lerp = in_z - z_int + y_lerp = in_y - y_int + x_lerp = in_x - x_int + + p = [[[0 for i in range(2)] for j in range(2)] for k in range(2)] + for k in range(2): + for j in range(2): + for i in range(2): + p[k][j][i] = get_3d_pixel( + data, + layout, + image_depth, + image_height, + image_width, + box_idx, + c, + z_int + k, + y_int + j, + x_int + i, + cc, + ) + l = [[0 for i in range(2)] for j in range(2)] + for j in range(2): + for i in range(2): + l[j][i] = _lerp(*p[j][i], x_lerp) + + top = _lerp(*l[0], y_lerp) + bottom = _lerp(*l[1], y_lerp) + value = _lerp(top, bottom, z_lerp) + + elif method == "cubic": + zint = te.floor(in_z).astype("int32") + zfract = in_z - te.floor(in_z) + + yint = te.floor(in_y).astype("int32") + yfract = in_y - te.floor(in_y) + + xint = te.floor(in_x).astype("int32") + xfract = in_x - te.floor(in_x) + + # Get the surrounding values + p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)] + for k in range(4): + for j in range(4): + for i in range(4): + p[k][j][i] = get_3d_pixel( + data, + layout, + image_depth, + image_height, + image_width, + box_idx, + c, + zint + k - 1, + yint + j - 1, + xint + i - 1, + cc, + ) + + wz = _cubic_spline_weights(zfract, alpha) + wy = _cubic_spline_weights(yfract, alpha) + wx = _cubic_spline_weights(xfract, alpha) + if exclude_outside: + for i in range(4): + wz[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_height), 0.0, wx[i] + ) + wy[i] = te.if_then_else( + te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i] + ) + wx[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + ) + sum_wz = sum(wz) + sum_wy = sum(wy) + sum_wx = sum(wx) + wz = [w / sum_wz for w in wz] + wy = [w / sum_wy for w in wy] + wx = [w / sum_wx for w in wx] + + l = [[0 for i in range(2)] for j in range(2)] + for j in range(2): + for i in range(2): + l[j][i] = _cubic_kerel(p[j][i], wx) + col0 = _cubic_kernel(l[0], wy) + col1 = _cubic_kernel(l[1], wy) + col2 = _cubic_kernel(l[2], wy) + col3 = _cubic_kernel(l[3], wy) + value = _cubic_kernel([col0, col1, col2, col3], wz) + + else: + raise ValueError("Unknown resize method:", method) + + if extrapolation_value is not None: + out = tvm.tir.if_then_else( + in_z < 0, + extrapolation_value, + tvm.tir.if_then_else(in_z > image_depth - 1, extrapolation_value, value), + ) + out = tvm.tir.if_then_else( + in_y < 0, + extrapolation_value, + tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), + ) + # use extrapolation_value if in_x is out of boundary + value = tvm.tir.if_then_else( + in_x < 0, + extrapolation_value, + tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), + ) + return _cast_output(value, data.dtype, out_dtype=out_dtype) + + def resize3d( data, size, layout="NCDHW", - method="nearest_neighbor", - coordinate_transformation_mode="align_corners", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + bicubic_alpha=-0.5, + bicubic_exclude=0, out_dtype=None, + output_shape=None, ): """Perform resize operation on the data. Parameters ---------- - inputs: tvm.te.Tensor + data : tvm.te.Tensor inputs is a 5-D tensor with shape [batch, channel, in_depth, in_height, in_width] or [batch, in_depth, in_height, in_width, channel] @@ -885,24 +1164,26 @@ def resize3d( Describes how to transform the coordinate in the resized tensor to the coordinate in the original tensor. Refer to the ONNX Resize operator specification for details. - Available options are "half_pixel", "align_corners" and "asymmetric". - method: {"linear", "nearest_neighbor"} + + method: {"linear", "nearest_neighbor", "cubic"} Method to be used for resizing. out_dtype: string, optional Type to return. If left None will be same as input type. + output_shape: tvm.tir.container.Array, optional + Shape to return. If left None will be inferred + (If shape is determined dynamically, pass out_dtype.shape as output_shape) + Returns ------- output : tvm.te.Tensor - 5-D with shape [batch, channel, in_depth*scale, in_height*scale, in_width*scale] + 4-D with shape [batch, channel, in_depth*scale, in_height*scale, in_width*scale] or [batch, in_depth*scale, in_height*scale, in_width*scale, channel] - or 5-D with shape [batch, channel-major, in_depth*scale, in_height*scale, in_width*scale, - channel-minor] + or 5-D with shape [batch, channel-major, in_depth*scale, in_height*scale, in_width*scale, channel-minor] """ method = method.lower() - if layout == "NDHWC": in_n, in_d, in_h, in_w, in_c = data.shape output_shape = [in_n, size[0], size[1], size[2], in_c] @@ -914,125 +1195,30 @@ def resize3d( in_n, in_c, in_d, in_h, in_w, in_cc = data.shape output_shape = [in_n, in_c, size[0], size[1], size[2], in_cc] - if coordinate_transformation_mode == "align_corners": - z_ratio = (in_d - 1).astype("float") / (size[0] - 1) - y_ratio = (in_h - 1).astype("float") / (size[1] - 1) - x_ratio = (in_w - 1).astype("float") / (size[2] - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - z_ratio = (in_d).astype("float") / (size[0]) - y_ratio = (in_h).astype("float") / (size[1]) - x_ratio = (in_w).astype("float") / (size[2]) - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) - ) - - def _get_pixel(n, c, z, y, x, cc): - z = tvm.te.max(tvm.te.min(z, in_d - 1), 0) - y = tvm.te.max(tvm.te.min(y, in_h - 1), 0) - x = tvm.te.max(tvm.te.min(x, in_w - 1), 0) - if layout == "NDHWC": - return data(n, z, y, x, c).astype("float") - if layout == "NCDHW": - return data(n, c, z, y, x).astype("float") - # else must be NCDHWxc - return data(n, c, z, y, x, cc).astype("float") - - def _get_indices(*indices): - if layout == "NDHWC": - n, z, y, x, c = indices - cc = None - elif layout == "NCDHW": - n, c, z, y, x = indices - cc = None - else: - n, c, z, y, x, cc = indices - - return n, c, z, y, x, cc - - def _cast_output(value): - if out_dtype: - dtype = out_dtype - else: - dtype = data.dtype - return value.astype(dtype) - - # Nearest neighbor computation - def _nearest_neighbor(*indices): - n, c, z, y, x, cc = _get_indices(*indices) - - in_z = z_ratio * z - in_y = y_ratio * y - in_x = x_ratio * x - - if coordinate_transformation_mode == "align_corners": - zint = te.round(in_z).astype("int32") - yint = te.round(in_y).astype("int32") - xint = te.round(in_x).astype("int32") - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - # Add epsilon to floor to prevent gpu rounding errors. - epsilon = 1e-5 - zint = te.floor(in_z + epsilon).astype("int32") - yint = te.floor(in_y + epsilon).astype("int32") - xint = te.floor(in_x + epsilon).astype("int32") - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format( - coordinate_transformation_mode - ) - ) - - return _cast_output(_get_pixel(n, c, zint, yint, xint, cc)) - - # Trilinear helper functions and computation. - def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - def _trilinear(*indices): - n, c, z, y, x, cc = _get_indices(*indices) - - if coordinate_transformation_mode == "half_pixel": - in_z = z_ratio * (z + 0.5) - 0.5 - in_y = y_ratio * (y + 0.5) - 0.5 - in_x = x_ratio * (x + 0.5) - 0.5 - else: - in_z = z_ratio * z - in_y = y_ratio * y - in_x = x_ratio * x - - zint = te.floor(in_z).astype("int32") - zfract = in_z - te.floor(in_z) - - xint = te.floor(in_x).astype("int32") - xfract = in_x - te.floor(in_x) + if isinstance(size, tuple): + size = list(size) - yint = te.floor(in_y).astype("int32") - yfract = in_y - te.floor(in_y) + for i in range(3): + if isinstance(size[i], int): + size[i] = tvm.tir.IntImm("int32", size[i]) - p000 = _get_pixel(n, c, zint, yint, xint, cc) - p001 = _get_pixel(n, c, zint, yint, xint + 1, cc) - p010 = _get_pixel(n, c, zint, yint + 1, xint, cc) - p011 = _get_pixel(n, c, zint, yint + 1, xint + 1, cc) - p100 = _get_pixel(n, c, zint + 1, yint, xint, cc) - p101 = _get_pixel(n, c, zint + 1, yint, xint + 1, cc) - p110 = _get_pixel(n, c, zint + 1, yint + 1, xint, cc) - p111 = _get_pixel(n, c, zint + 1, yint + 1, xint + 1, cc) - - dep00 = _lerp(p000, p100, zfract) - dep01 = _lerp(p001, p101, zfract) - dep10 = _lerp(p010, p110, zfract) - dep11 = _lerp(p011, p111, zfract) - col0 = _lerp(dep00, dep01, xfract) - col1 = _lerp(dep10, dep11, xfract) - value = _lerp(col0, col1, yfract) - return _cast_output(value) - - # Determine which interpolation method to use then run it. - if method == "nearest_neighbor": - compute_func = _nearest_neighbor - elif method == "trilinear": - compute_func = _trilinear - else: - raise ValueError("%s method is not supported." % method) + def compute_func(*indices): + return _resize_3d( + indices, + data, + in_d, + in_h, + in_w, + size[0], + size[1], + size[2], + method=method, + layout=layout, + coordinate_transformation_mode=coordinate_transformation_mode, + rounding_method=rounding_method, + alpha=bicubic_alpha, + exclude_outside=bicubic_exclude, + out_dtype=out_dtype, + ) - return te.compute(output_shape, compute_func, name="resize3d", tag=tag.INJECTIVE) + return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE) diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index 69c77ef0cfec..ee779841505c 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -239,12 +239,16 @@ bool Resize3DRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. Expr MakeResize3D(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, DataType out_dtype) { + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype) { auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->rounding_method = rounding_method; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize3d"); return Call(op, {data}, Attrs(attrs), {}); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7a76b14b6fce..16b855bc9b39 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3584,7 +3584,7 @@ def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, ex verify_with_ort(model, [ishape], [oshape], use_vm=True, opset=11, freeze_params=True) - for ndim in [1, 2]: + for ndim in [1, 2, 3]: method = "nearest" for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: # upsampling From 1b1cfe136c680e5a034bc0a4b1a7cc1892f86247 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 25 Jun 2021 18:50:27 -0600 Subject: [PATCH 05/12] fix lint --- python/tvm/relay/frontend/onnx.py | 18 ++++++----- python/tvm/relay/op/image/image.py | 48 ------------------------------ python/tvm/topi/image/resize.py | 25 ++++++++-------- 3 files changed, 24 insertions(+), 67 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 73303b28a9b2..79c185ddbd3c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2406,17 +2406,19 @@ def _impl_v10(cls, inputs, attr, params): scale = inputs[1] size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale ndims = len(infer_shape(inputs[0])) + out = None if ndims == 3: out_size = fold_constant(_op.strided_slice(size, [2], [3])) - return _op.image.resize1d(inputs[0], out_size, "NCW", method, "asymmetric") + out = _op.image.resize1d(inputs[0], out_size, "NCW", method, "asymmetric") elif ndims == 4: out_size = fold_constant(_op.strided_slice(size, [2], [4])) - return _op.image.resize2d(inputs[0], out_size, "NCHW", method, "asymmetric") + out = _op.image.resize2d(inputs[0], out_size, "NCHW", method, "asymmetric") elif ndims == 5: out_size = fold_constant(_op.strided_slice(size, [2], [5])) - return _op.image.resize3d(inputs[0], out_size, "NCDHW", method, "asymmetric") + out = _op.image.resize3d(inputs[0], out_size, "NCDHW", method, "asymmetric") else: raise NotImplementedError("Resize only supports 3, 4, or 5 dims") + return out @classmethod def _impl_v11(cls, inputs, attr, params): @@ -2449,25 +2451,27 @@ def _impl_v11(cls, inputs, attr, params): assert len(scale_shape) != 0, "One of scale or size should be passed." size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale out_size = fold_constant(_op.strided_slice(size, [2], [4])) - + out = None if ndims == 3: out_size = fold_constant(_op.strided_slice(size, [2], [3])) - return _op.image.resize1d( + out = _op.image.resize1d( inputs[0], out_size, "NCW", method, coord_trans, nearest_mode, alpha, exclude ) elif ndims == 4: out_size = fold_constant(_op.strided_slice(size, [2], [4])) - return _op.image.resize2d( + out = _op.image.resize2d( inputs[0], out_size, "NCHW", method, coord_trans, nearest_mode, alpha, exclude ) elif ndims == 5: out_size = fold_constant(_op.strided_slice(size, [2], [5])) - return _op.image.resize3d( + out = _op.image.resize3d( inputs[0], out_size, "NCDHW", method, coord_trans, nearest_mode, alpha, exclude ) else: raise NotImplementedError("Resize only supports 3, 4, or 5 dims") + return out + class NonZero(OnnxOpConverter): """Operator converter for NonZero""" diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 74edff140d6f..5ba57454924c 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -182,54 +182,6 @@ def resize2d( ) -def resize3d( - data, - size, - layout="NCDHW", - method="trilinear", - coordinate_transformation_mode="half_pixel", - out_dtype=None, -): - """Image resize 3D operator. - - This operator takes data as input and does 3D scaling to the given scale factor. - In the default case, where the data_layout is `NCDHW` - with data of shape `(n, c, d, h, w)` - out will have a shape `(n, c, size[0], size[1], size[2])` - - method indicates the algorithm to be used while calculating the out value - and method can be one of ("trilinear", "nearest_neighbor") - - Parameters - ---------- - data : relay.Expr - The input data to the operator. - - size: Tuple of Expr - The out size to which the image will be resized. - - layout : str, optional - Layout of the input. - - method : str, optional - Scale method to used [nearest_neighbor, trilinear]. - - coordinate_transformation_mode : string, optional - Describes how to transform the coordinate in the resized tensor - to the coordinate in the original tensor. - [half_pixel, align_corners, asymmetric] - - out_dtype : str, optional - Type to return. If left None returns the same type as input. - - Returns - ------- - result: relay.Expr - The resized result. - """ - return _make.resize3d(data, size, layout, method, coordinate_transformation_mode, out_dtype) - - def resize3d( data, size, diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index ae7d7bd02fbd..f32cdce3b07d 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -62,6 +62,7 @@ def get_2d_indices(indices, layout="NCHW"): def get_3d_indices(indices, layout="NCDHW"): + """Get 3d indices""" if layout == "NDHWC": n, z, y, x, c = indices cc = None @@ -150,6 +151,7 @@ def get_iny_inx( def get_closest_index(in_x, rounding_method, boxes): + """get the closest index to a value based on a certain rounding method""" if rounding_method == "round" or boxes is not None: closest_x_index = te.round(in_x).astype("int32") elif rounding_method == "round_prefer_floor": @@ -275,13 +277,12 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): if boxes is not None: # TODO(mbrookhart): Find an example of this raise NotImplementedError("resize1d with image boxes not yet implemented") - else: - in_x = get_inx( - x, - image_width, - target_width, - coordinate_transformation_mode, - ) + in_x = get_inx( + x, + image_width, + target_width, + coordinate_transformation_mode, + ) if method == "nearest_neighbor": if rounding_method == "": @@ -986,10 +987,9 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): if boxes is not None: # TODO(mbrookhart): Find an example of this raise NotImplementedError("resize1d with image boxes not yet implemented") - else: - in_z = get_inx(z, image_depth, target_depth, coordinate_transformation_mode) - in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode) - in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) + in_z = get_inx(z, image_depth, target_depth, coordinate_transformation_mode) + in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode) + in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) if method == "nearest_neighbor": if rounding_method == "": @@ -1181,7 +1181,8 @@ def resize3d( output : tvm.te.Tensor 4-D with shape [batch, channel, in_depth*scale, in_height*scale, in_width*scale] or [batch, in_depth*scale, in_height*scale, in_width*scale, channel] - or 5-D with shape [batch, channel-major, in_depth*scale, in_height*scale, in_width*scale, channel-minor] + or 5-D with shape + [batch, channel-major, in_depth*scale, in_height*scale, in_width*scale, channel-minor] """ method = method.lower() if layout == "NDHWC": From ce3c0114542fe4ed9ea60d65aff240d29dde1ce5 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 28 Jun 2021 09:27:22 -0600 Subject: [PATCH 06/12] fix relay tests from API change --- python/tvm/relay/frontend/onnx.py | 18 ++++++++--------- python/tvm/relay/op/image/image.py | 2 +- python/tvm/relay/op/op_attrs.py | 5 +++++ python/tvm/topi/image/resize.py | 1 + python/tvm/topi/nn/upsampling.py | 4 ++++ tests/python/frontend/onnx/test_forward.py | 2 +- .../relay/dyn/test_dynamic_op_level2.py | 5 ++++- .../relay/dyn/test_dynamic_op_level5.py | 18 ++++++++--------- tests/python/relay/test_any.py | 8 ++++---- tests/python/relay/test_op_level2.py | 9 +++++++-- tests/python/relay/test_op_level5.py | 18 +++++++++-------- .../relay/test_pass_convert_op_layout.py | 20 +++++++++---------- .../relay/test_pass_dynamic_to_static.py | 10 +++++----- .../topi/python/test_topi_upsampling.py | 4 ++-- 14 files changed, 72 insertions(+), 52 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 79c185ddbd3c..c3108ff890b1 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -457,7 +457,7 @@ def _impl_v1(cls, inputs, attr, params): kernel_type = infer_type(inputs[1]) kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] - print(input_shape, kernel_shapes) + if "kernel_shape" not in attr: attr["kernel_shape"] = kernel_shapes[0][2:] @@ -1200,7 +1200,13 @@ def _impl_v9(cls, inputs, attr, params): layout = "NCDHW" out = _op.nn.upsampling3d( - inputs[0], scale_d, scale_h, scale_w, layout=layout, method=method + inputs[0], + scale_d, + scale_h, + scale_w, + layout=layout, + method=method, + coordinate_transformation_mode="asymmetric", ) # in 2d case, use dynamic op else: @@ -1365,13 +1371,7 @@ def _impl_v10(cls, inputs, attr, params): ends = inputs[2] axes = inputs[3] steps = inputs[4] - print("----------Slice------------") - print(inputs[0]) - print(inputs[1]) - print(inputs[2]) - print(inputs[3]) - print(inputs[4]) - print("----------/Slice------------") + ishape = infer_shape(inputs[0]) data_rank = len(ishape) diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 5ba57454924c..7f5bd80159f9 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -100,7 +100,7 @@ def resize2d( data, size, layout="NCHW", - method="bilinear", + method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", cubic_alpha=-0.5, diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index a7a6efc24b71..2d185bcee798 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -139,6 +139,11 @@ class DeformableConv2DAttrs(Attrs): """Attributes for nn.deformable_conv2d""" +@tvm._ffi.register_object("relay.attrs.Resize1DAttrs") +class Resize1DAttrs(Attrs): + """Attributes for image.resize1d""" + + @tvm._ffi.register_object("relay.attrs.Resize2DAttrs") class Resize2DAttrs(Attrs): """Attributes for image.resize2d""" diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index f32cdce3b07d..65d3870c9597 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -1184,6 +1184,7 @@ def resize3d( or 5-D with shape [batch, channel-major, in_depth*scale, in_height*scale, in_width*scale, channel-minor] """ + method = method.lower() if layout == "NDHWC": in_n, in_d, in_h, in_w, in_c = data.shape diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py index 0ebe96d3acdd..36b9349a139d 100644 --- a/python/tvm/topi/nn/upsampling.py +++ b/python/tvm/topi/nn/upsampling.py @@ -92,6 +92,8 @@ def upsampling( else: raise ValueError("not support this layout {} yet".format(layout)) coord_trans = "align_corners" if align_corners else "asymmetric" + if method[0:2] == "bi": + method = method[2:] return topi.image.resize2d( data, reshape_size, @@ -188,6 +190,8 @@ def upsampling3d( ) else: raise ValueError("not support this layout {} yet".format(layout)) + if method[0:3] == "tri": + method = method[3:] return topi.image.resize3d( data, resize_shape, diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 16b855bc9b39..6fa6c97d76c0 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1370,7 +1370,7 @@ def verify_upsample3d_trilinear(): in_array, (3 * scale, 3 * scale, 3 * scale), "NCDHW", - coordinate_transformation_mode="half_pixel", + coordinate_transformation_mode="asymmetric", ) ref_array = np.array(scales) diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index dca5dd6d4384..8e0a11a6c9d2 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -87,7 +87,7 @@ def test_dyn_upsampling_infer_type_const(): @tvm.testing.uses_gpu def test_dyn_upsampling3d_run(): def verify_upsampling3d( - dshape, scale_d, scale_h, scale_w, layout, method, coord_trans="half_pixel" + dshape, scale_d, scale_h, scale_w, layout, method, coord_trans="asymmetric" ): if layout == "NCDHW": @@ -99,6 +99,9 @@ def verify_upsampling3d( x_data = np.random.uniform(size=(n, d, h, w, c)).astype("float32") if method == "nearest_neighbor": + assert ( + coord_trans == "asymmetric" + ), "topi reference only support asymmetric nearest neighbor" ref_res = tvm.topi.testing.upsampling3d_python( x_data, (scale_d, scale_h, scale_w), layout ) diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py index 78e2c232c08e..305fe4d2380f 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level5.py +++ b/tests/python/relay/dyn/test_dynamic_op_level5.py @@ -27,25 +27,25 @@ import tvm.testing -def test_resize_infer_type(): +def test_resize2d_infer_type(): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) size = relay.var("size", relay.TensorType((2,), "int8")) - z = relay.image.resize(x, size) + z = relay.image.resize2d(x, size) zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") @tvm.testing.uses_gpu -def test_resize(): - def verify_resize(dshape, scale, method, layout): +def test_resize2d(): + def verify_resize2d(dshape, scale, method, layout): if layout == "NHWC": size = (dshape[1] * scale, dshape[2] * scale) else: size = (dshape[2] * scale, dshape[3] * scale) size = np.array(size).astype("int64") x_data = np.random.uniform(size=dshape).astype("float32") - if method == "bilinear": + if method == "linear": ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout) else: ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) @@ -53,7 +53,7 @@ def verify_resize(dshape, scale, method, layout): size_var = relay.var("size", relay.TensorType((2,), "int64")) coord_trans = "asymmetric" if method == "nearest_neighbor" else "align_corners" - z = relay.image.resize( + z = relay.image.resize2d( x, size_var, layout, method, coordinate_transformation_mode=coord_trans ) @@ -67,10 +67,10 @@ def verify_resize(dshape, scale, method, layout): op_res = intrp.evaluate()(x_data, size) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) - for method in ["bilinear", "nearest_neighbor"]: + for method in ["linear", "nearest_neighbor"]: for layout in ["NCHW", "NHWC"]: - verify_resize((1, 4, 4, 4), 2, method, layout) - verify_resize((2, 8, 17, 20), 7, method, layout) + verify_resize2d((1, 4, 4, 4), 2, method, layout) + verify_resize2d((2, 8, 17, 20), 7, method, layout) if __name__ == "__main__": diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 13f5525bfee8..e94b5145ccc2 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1275,7 +1275,7 @@ def test_any_ndarray_size(): verify_any_ndarray_size((1, 2, 3, 4)) -def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shape): +def verify_any_resize2d(data_shape, scale, layout, static_data_shape, ref_out_shape): mod = tvm.IRModule() dtype = "float32" data = relay.var("data", shape=data_shape, dtype=dtype) @@ -1283,7 +1283,7 @@ def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shap size = (data_shape[1] * scale, data_shape[2] * scale) else: size = (data_shape[2] * scale, data_shape[3] * scale) - y = relay.image.resize(data, size, layout) + y = relay.image.resize2d(data, size, layout) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) check_result([data_np], mod, ref_out_shape, assert_shape=True) @@ -1291,14 +1291,14 @@ def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shap @tvm.testing.uses_gpu def test_any_resize(): - verify_any_resize( + verify_any_resize2d( data_shape=(relay.Any(), 4, 4, 4), scale=2, layout="NHWC", static_data_shape=(1, 4, 4, 4), ref_out_shape=(1, 8, 8, 4), ) - verify_any_resize( + verify_any_resize2d( data_shape=(relay.Any(), 8, 17, 20), scale=3, layout="NCHW", diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 50fc0622ee6e..907ce79d82f4 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1448,6 +1448,7 @@ def get_shape(): align_corners=align_corners, ) func = relay.Function([x], y) + data = np.random.uniform(size=dshape).astype(dtype) if method == "nearest_neighbor": ref = tvm.topi.testing.upsampling_python(data, (scale_h, scale_w), layout) @@ -1518,8 +1519,12 @@ def get_shape(): coordinate_transformation_mode=coordinate_transformation_mode, ) func = relay.Function([x], y) + data = np.random.uniform(size=dshape).astype(dtype) if method == "nearest_neighbor": + assert ( + coordinate_transformation_mode == "asymmetric" + ), "topi reference only support asymmetric nearest neighbor" ref = tvm.topi.testing.upsampling3d_python(data, (scale_d, scale_h, scale_w), layout) else: ref = tvm.topi.testing.trilinear_resize3d_python( @@ -1535,9 +1540,9 @@ def get_shape(): @tvm.testing.uses_gpu def test_upsampling3d(): - _test_upsampling3d("NCDHW", "nearest_neighbor") + _test_upsampling3d("NCDHW", "nearest_neighbor", "asymmetric") _test_upsampling3d("NCDHW", "trilinear", "align_corners") - _test_upsampling3d("NDHWC", "nearest_neighbor") + _test_upsampling3d("NDHWC", "nearest_neighbor", "asymmetric") _test_upsampling3d("NDHWC", "trilinear", "align_corners") diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index e27520339f36..9b12b9d12535 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -30,12 +30,12 @@ def test_resize_infer_type(): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) th, tw = te.var("th"), te.var("tw") - z = relay.image.resize(x, (th, tw)) + z = relay.image.resize2d(x, (th, tw)) zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) - z = relay.image.resize(x, (100, 200), "NCHW", "bilinear", "align_corners") + z = relay.image.resize2d(x, (100, 200), "NCHW", "linear", "align_corners") assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") @@ -51,12 +51,14 @@ def verify_resize(dshape, scale, method, layout, coord_trans): x_data = np.random.uniform(size=dshape).astype("float32") - if method == "bilinear": + if method == "linear": ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout, coord_trans) else: ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) x = relay.var("x", relay.TensorType(dshape, "float32")) - z = relay.image.resize(x, size, layout, method, coordinate_transformation_mode=coord_trans) + z = relay.image.resize2d( + x, size, layout, method, coordinate_transformation_mode=coord_trans + ) assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") @@ -68,7 +70,7 @@ def verify_resize(dshape, scale, method, layout, coord_trans): op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) - for method in ["nearest_neighbor", "bilinear"]: + for method in ["nearest_neighbor", "linear"]: for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric for layout in ["NHWC", "NCHW"]: verify_resize((1, 4, 4, 4), 2, method, layout, coord_trans) @@ -92,7 +94,7 @@ def test_resize3d_infer_type(): assert zz.checked_type == relay.TensorType((n, c, td, th, tw), "int8") x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8")) - z = relay.image.resize3d(x, (10, 10, 20), "NCDHW", "trilinear", "align_corners") + z = relay.image.resize3d(x, (10, 10, 20), "NCDHW", "linear", "align_corners") assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 10, 10, 20), "int8") @@ -107,7 +109,7 @@ def verify_resize(dshape, scale, method, layout): size = (dshape[2] * scale, dshape[3] * scale, dshape[4] * scale) x_data = np.random.uniform(size=dshape).astype("float32") - if method == "trilinear": + if method == "linear": ref_res = tvm.topi.testing.trilinear_resize3d_python(x_data, size, layout) else: ref_res = tvm.topi.testing.upsampling3d_python(x_data, (scale, scale, scale), layout) @@ -123,7 +125,7 @@ def verify_resize(dshape, scale, method, layout): op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) - for method in ["trilinear", "nearest_neighbor"]: + for method in ["linear", "nearest_neighbor"]: for layout in ["NDHWC", "NCDHW"]: verify_resize((1, 4, 4, 4, 4), 2, method, layout) diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 88590c946e88..fafab3ee3584 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1797,24 +1797,24 @@ def expected(): _test_conv_reduce_convert_layout2() -def test_image_resize_convert_layout(): +def test_image_resize2d_convert_layout(): def _test_image_resize_convert_layout_nchw_to_nhwc(): def before(): x = relay.var("x", shape=(1, 2, 4, 4)) - y = relay.image.resize(x, (8, 8)) + y = relay.image.resize2d(x, (8, 8)) y = relay.Function([x], y) return y def expected(): x = relay.var("x", shape=(1, 2, 4, 4)) x = relay.layout_transform(x, "NCHW", "NHWC") - y = relay.image.resize(x, (8, 8), layout="NHWC") + y = relay.image.resize2d(x, (8, 8), layout="NHWC") y = relay.layout_transform(y, "NHWC", "NCHW") y = relay.Function(relay.analysis.free_vars(y), y) return y a = before() - a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NHWC"]})) + a = run_opt_pass(a, transform.ConvertLayout({"image.resize2d": ["NHWC"]})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -1822,20 +1822,20 @@ def expected(): def _test_image_resize_convert_layout_nhwc_to_nchw(): def before(): x = relay.var("x", shape=(1, 4, 4, 2)) - y = relay.image.resize(x, (8, 8), layout="NHWC") + y = relay.image.resize2d(x, (8, 8), layout="NHWC") y = relay.Function([x], y) return y def expected(): x = relay.var("x", shape=(1, 4, 4, 2)) x = relay.layout_transform(x, "NHWC", "NCHW") - y = relay.image.resize(x, (8, 8), layout="NCHW") + y = relay.image.resize2d(x, (8, 8), layout="NCHW") y = relay.layout_transform(y, "NCHW", "NHWC") y = relay.Function(relay.analysis.free_vars(y), y) return y a = before() - a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NCHW"]})) + a = run_opt_pass(a, transform.ConvertLayout({"image.resize2d": ["NCHW"]})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -1844,7 +1844,7 @@ def expected(): _test_image_resize_convert_layout_nhwc_to_nchw() -def test_conv_image_resize_convert_layout(): +def test_conv_image_resize2d_convert_layout(): """Check that layout transforms are propagated through image resize.""" def before(): @@ -1859,7 +1859,7 @@ def before(): data_layout="NHWC", kernel_layout="HWIO", ) - y = relay.image.resize(y, (112, 112), layout="NHWC") + y = relay.image.resize2d(y, (112, 112), layout="NHWC") y = relay.Function(analysis.free_vars(y), y) return y @@ -1869,7 +1869,7 @@ def expected(): x = relay.layout_transform(x, "NHWC", "NCHW") w = relay.layout_transform(w, "HWIO", "OIHW") y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)) - y = relay.image.resize(y, (112, 112), layout="NCHW") + y = relay.image.resize2d(y, (112, 112), layout="NCHW") y = relay.layout_transform(y, "NCHW", "NHWC") y = relay.Function(analysis.free_vars(y), y) return y diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 9f7f3deebeb8..09f4b43449d7 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -248,7 +248,7 @@ def verify_ones_zeros(shape, dtype): @tvm.testing.uses_gpu -def test_dynamic_to_static_resize(): +def test_dynamic_to_static_resize2d(): def verify_resize(shape, scale, method, layout): if layout == "NHWC": size = (shape[1] * scale, shape[2] * scale) @@ -258,7 +258,7 @@ def verify_resize(shape, scale, method, layout): x = relay.var("x", relay.TensorType(shape, "float32")) size_var = relay.const(np.array(size).astype("float32")) coord_trans = "asymmetric" if method == "nearest_neighbor" else "align_corners" - z = relay.image.resize( + z = relay.image.resize2d( x, size_var, layout, method, coordinate_transformation_mode=coord_trans ) @@ -267,17 +267,17 @@ def verify_resize(shape, scale, method, layout): zz = func2.body assert isinstance(zz, relay.Call) - assert zz.op == relay.op.get("image.resize") + assert zz.op == relay.op.get("image.resize2d") x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") - if method == "bilinear": + if method == "linear": ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout) else: ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) verify_func(func2, [x_data], ref_res, rtol=1e-4, atol=1e-6) - for method in ["bilinear", "nearest_neighbor"]: + for method in ["linear", "nearest_neighbor"]: for layout in ["NCHW", "NHWC"]: verify_resize((1, 4, 4, 4), 2, method, layout) diff --git a/tests/python/topi/python/test_topi_upsampling.py b/tests/python/topi/python/test_topi_upsampling.py index 0ab0e64af4c7..8dfe7d7a24b8 100644 --- a/tests/python/topi/python/test_topi_upsampling.py +++ b/tests/python/topi/python/test_topi_upsampling.py @@ -213,7 +213,7 @@ def verify_upsampling3d( scale_w, layout=layout, method=method, - coordinate_transformation_mode="half_pixel", + coordinate_transformation_mode="asymmetric", ) if method == "trilinear": @@ -223,7 +223,7 @@ def verify_upsampling3d( int(round(in_width * scale_w)), ) b_np = tvm.topi.testing.trilinear_resize3d_python( - a_np, out_size, layout, coordinate_transformation_mode="half_pixel" + a_np, out_size, layout, coordinate_transformation_mode="asymmetric" ) else: b_np = tvm.topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) From e2e6ea68167d8ade32cc3d8920650748fcbd39df Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 29 Jun 2021 10:14:33 -0600 Subject: [PATCH 07/12] refactor topi tests, docs --- docs/langref/relay_op.rst | 4 ++- python/tvm/topi/image/resize.py | 21 ++---------- tests/python/topi/python/test_topi_image.py | 38 ++++++++++----------- 3 files changed, 25 insertions(+), 38 deletions(-) diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index febe542b83b1..3e797fc93b31 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -181,7 +181,9 @@ This level enables additional math and transform operators. .. autosummary:: :nosignatures: - tvm.relay.image.resize + tvm.relay.image.resize1d + tvm.relay.image.resize2d + tvm.relay.image.resize3d tvm.relay.image.crop_and_resize tvm.relay.image.dilation2d tvm.relay.vision.multibox_prior diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 65d3870c9597..e8d071ae73b2 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -141,15 +141,6 @@ def get_inx(x, image_width, target_width, coordinate_transformation_mode): return in_x -def get_iny_inx( - y, x, image_height, image_width, target_height, target_width, coordinate_transformation_mode -): - """Infer input x,y from output x,y with various coordinate transformation methods""" - in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) - in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode) - return in_y, in_x - - def get_closest_index(in_x, rounding_method, boxes): """get the closest index to a value based on a certain rounding method""" if rounding_method == "round" or boxes is not None: @@ -172,6 +163,7 @@ def get_closest_index(in_x, rounding_method, boxes): def _lerp(A, B, t): + """Perform Linear interpolation in 1D""" return A * (1.0 - t) + B * t @@ -568,15 +560,8 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): in_y = y1 * (image_height - 1) + h_scale * y in_x = x1 * (image_width - 1) + w_scale * x else: - in_y, in_x = get_iny_inx( - y, - x, - image_height, - image_width, - target_height, - target_width, - coordinate_transformation_mode, - ) + in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) + in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode) if method == "nearest_neighbor": if rounding_method == "": diff --git a/tests/python/topi/python/test_topi_image.py b/tests/python/topi/python/test_topi_image.py index 2730783907fd..381cdc08d890 100644 --- a/tests/python/topi/python/test_topi_image.py +++ b/tests/python/topi/python/test_topi_image.py @@ -24,7 +24,7 @@ from tvm.contrib.pickle_memoize import memoize -def verify_resize( +def verify_resize2d( batch, in_channel, in_height, @@ -33,7 +33,7 @@ def verify_resize( out_width, layout="NCHW", coord_trans="align_corners", - method="bilinear", + method="linear", ): if layout == "NCHW": A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype="float32") @@ -47,14 +47,14 @@ def verify_resize( a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype) else: raise NotImplementedError("Layout not supported {} ".format(layout)) - B = topi.image.resize( + B = topi.image.resize2d( A, (out_height, out_width), layout=layout, coordinate_transformation_mode=coord_trans, method=method, ) - if method == "bilinear": + if method == "linear": b_np = tvm.topi.testing.bilinear_resize_python( a_np, (out_height, out_width), layout, coord_trans ) @@ -82,19 +82,19 @@ def check_target(target, dev): @tvm.testing.uses_gpu -def test_resize(): +def test_resize2d(): # Scale NCHW - verify_resize(4, 16, 32, 32, 50, 50, "NCHW") + verify_resize2d(4, 16, 32, 32, 50, 50, "NCHW") # Scale NCHW + Align Corners - verify_resize(6, 32, 64, 64, 20, 20, "NCHW") + verify_resize2d(6, 32, 64, 64, 20, 20, "NCHW") # Scale NHWC - verify_resize(4, 16, 32, 32, 50, 50, "NHWC") + verify_resize2d(4, 16, 32, 32, 50, 50, "NHWC") # Scale NHWC + Align Corners - verify_resize(6, 32, 64, 64, 20, 20, "NHWC") - for method in ["nearest_neighbor", "bilinear"]: - for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric - for layout in ["NCHW", "NHWC"]: - verify_resize(4, 16, 32, 32, 50, 50, layout, coord_trans, method=method) + verify_resize2d(6, 32, 64, 64, 20, 20, "NHWC") + for layout in ["NCHW", "NHWC"]: + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "asymmetric", method="nearest_neighbor") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "half_pixel", method="linear") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "asymmetric", method="linear") def verify_resize3d( @@ -107,8 +107,8 @@ def verify_resize3d( out_height, out_width, layout="NCDHW", - coordinate_transformation_mode="half_pixel", - method="trilinear", + coordinate_transformation_mode="asymmetric", + method="linear", ): if layout == "NCDHW": A = te.placeholder( @@ -139,7 +139,7 @@ def verify_resize3d( method=method, ) - if method == "trilinear": + if method == "linear": b_np = tvm.topi.testing.trilinear_resize3d_python( a_np, (out_depth, out_height, out_width), layout, coordinate_transformation_mode ) @@ -150,7 +150,6 @@ def verify_resize3d( b_np = tvm.topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) def check_target(target, dev): - print("Running on target: %s" % target) with tvm.target.Target(target): s = tvm.topi.testing.get_injective_schedule(target)(B) a = tvm.nd.array(a_np, dev) @@ -167,14 +166,15 @@ def check_target(target, dev): @tvm.testing.uses_gpu def test_resize3d(): # Trilinear - verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NCDHW") - verify_resize3d(1, 8, 16, 16, 16, 25, 25, 25, "NDHWC") + verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "half_pixel") + verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "half_pixel") verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "align_corners") verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "align_corners") verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "asymmetric") verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "asymmetric") # Nearest neighbor + # Test kernel only supports asymmetric verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NCDHW", method="nearest_neighbor") verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NDHWC", method="nearest_neighbor") From c1158a54a4cf39105913c9545c16f246af9c80e6 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 29 Jun 2021 14:43:58 -0600 Subject: [PATCH 08/12] fix method naming in framework frontends fix more frontend issues --- python/tvm/relay/frontend/keras.py | 1 + python/tvm/relay/frontend/pytorch.py | 12 +++++++----- python/tvm/relay/frontend/tensorflow_ops.py | 6 +++--- python/tvm/relay/frontend/tflite.py | 4 ++-- tests/python/frontend/onnx/test_forward.py | 19 +++++++++++-------- 5 files changed, 24 insertions(+), 18 deletions(-) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 63521a67b065..aa185923d02e 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -725,6 +725,7 @@ def _convert_upsample3d(inexpr, keras_layer, etab): params["scale_h"] = h params["scale_w"] = w params["layout"] = etab.data_layout + params["coordinate_transformation_mode"] = "asymmetric" out = _op.nn.upsampling3d(inexpr, **params) return out diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 5c252739f190..72fd716b3fa8 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1798,7 +1798,7 @@ def get_upsample_out_size(self, inputs, method): else: out_size.append(size) else: - scale_index = 3 if method in ["bilinear", "trilinear"] else 2 + scale_index = 3 if method == "linear" else 2 scales = inputs[scale_index] assert scales is not None, "neither out size nor scale provided" assert isinstance(scales, list) @@ -1813,7 +1813,7 @@ def upsample(inputs, input_types): data = inputs[0] out_size = self.get_upsample_out_size(inputs, method) - if len(inputs) > 2 and method == "bilinear": + if len(inputs) > 2 and method == "linear": align_corners = inputs[2] else: align_corners = False @@ -1845,7 +1845,7 @@ def upsample3d(inputs, input_types): data = inputs[0] out_size = self.get_upsample_out_size(inputs, method) - if len(inputs) > 2 and method == "trilinear": + if len(inputs) > 2 and method == "linear": align_corners = inputs[2] else: align_corners = False @@ -2195,6 +2195,8 @@ def interpolate(self, inputs, input_types): method = inputs[3] if method.startswith("nearest"): method = "nearest_neighbor" + elif method[0:2] == "bi": + method = method[2:] if method == "nearest_neighbor": coord_trans = "asymmetric" @@ -2473,9 +2475,9 @@ def create_convert_map(self): "aten::clamp": self.clamp, "aten::clamp_": self.clamp, "aten::detach": self.identity, - "aten::upsample_bilinear2d": self.make_upsample("bilinear"), + "aten::upsample_bilinear2d": self.make_upsample("linear"), "aten::upsample_nearest2d": self.make_upsample("nearest_neighbor"), - "aten::upsample_trilinear3d": self.make_upsample3d("trilinear"), + "aten::upsample_trilinear3d": self.make_upsample3d("linear"), "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"), "aten::expand_as": self.expand_as, "aten::lt": self.make_elemwise("less"), diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 004174f076fd..797ff51ace7a 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1075,7 +1075,7 @@ def _impl(inputs, attr, params, mod): # Ignore the new attributes from TF2.0, for now. return AttrCvt( - op_name="resize", ignores=["Tdim", "half_pixel_centers"], extras={"method": method} + op_name="resize2d", ignores=["Tdim", "half_pixel_centers"], extras={"method": method} )(inputs, attr) return _impl @@ -2943,8 +2943,8 @@ def _impl(inputs, attr, params, mod): "Relu": AttrCvt("relu"), "Relu6": _relu6(), "Reshape": _reshape(), - "ResizeBicubic": _resize("bilinear"), - "ResizeBilinear": _resize("bilinear"), + "ResizeBicubic": _resize("cubic"), + "ResizeBilinear": _resize("linear"), "ResizeNearestNeighbor": _resize("nearest_neighbor"), "ReverseV2": _reverse_v2(), "RightShift": AttrCvt("right_shift"), diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 0dee44cd86ec..42096ad9af2f 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -630,7 +630,7 @@ def _convert_resize(self, method, op): # Options - align_corners (bool) resize_options = None align_corners = False - bilinear_method = method == "bilinear" + bilinear_method = method == "linear" if bilinear_method: assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions resize_options = ResizeBilinearOptions() @@ -656,7 +656,7 @@ def _convert_resize(self, method, op): def convert_resize_bilinear(self, op): """Convert TFLite RESIZE_BILINEAR""" - return self._convert_resize("bilinear", op) + return self._convert_resize("linear", op) def convert_resize_nearest_neighbor(self, op): """Convert TFLite RESIZE_NEAREST_NEIGHBOR""" diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6fa6c97d76c0..e2c865b75b3d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3595,14 +3595,17 @@ def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, ex verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method, coord_trans) verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method, coord_trans) - method = "linear" - # upsampling - verify([1, 16] + [32] * ndim, [1, 16] + [64] * ndim, [], method) - # downsampling - verify([1, 16] + [32] * ndim, [1, 16] + [16] * ndim, [], method) - # scales are specified instead of sizes - verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method) - verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method) + if ndim == 2: + ## TODO(mbrookhart): ONNX Runtime in CI only supports 2D linear resize + ## Remove this condition when updating CI + method = "linear" + # upsampling + verify([1, 16] + [32] * ndim, [1, 16] + [64] * ndim, [], method) + # downsampling + verify([1, 16] + [32] * ndim, [1, 16] + [16] * ndim, [], method) + # scales are specified instead of sizes + verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method) + verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method) if ndim == 2: # ONNX Runtime only supports cubic interpolation for 2D images From 1ca846a505b99c2635a002695885f562f672c4e5 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 2 Jul 2021 14:28:20 -0600 Subject: [PATCH 09/12] refactor resize tests to reuse components, add more coordinate tranform modes to tests --- python/tvm/topi/testing/__init__.py | 4 +- .../topi/testing/bilinear_resize_python.py | 105 ---------- python/tvm/topi/testing/resize_python.py | 192 ++++++++++++++++++ .../topi/testing/trilinear_resize3d_python.py | 111 ---------- python/tvm/topi/testing/upsampling_python.py | 136 ------------- tests/python/frontend/coreml/test_forward.py | 11 +- tests/python/frontend/onnx/test_forward.py | 5 +- .../relay/dyn/test_dynamic_op_level2.py | 34 ++-- .../relay/dyn/test_dynamic_op_level5.py | 9 +- tests/python/relay/test_op_level2.py | 31 ++- tests/python/relay/test_op_level5.py | 24 +-- .../relay/test_pass_dynamic_to_static.py | 21 +- tests/python/topi/python/test_topi_image.py | 48 ++--- .../topi/python/test_topi_upsampling.py | 30 ++- 14 files changed, 290 insertions(+), 471 deletions(-) delete mode 100644 python/tvm/topi/testing/bilinear_resize_python.py create mode 100644 python/tvm/topi/testing/resize_python.py delete mode 100644 python/tvm/topi/testing/trilinear_resize3d_python.py delete mode 100644 python/tvm/topi/testing/upsampling_python.py diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index afb251417315..871059bf5ab4 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -35,9 +35,7 @@ from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .dilate_python import dilate_python from .softmax_python import softmax_python, log_softmax_python -from .upsampling_python import upsampling_python, upsampling3d_python -from .bilinear_resize_python import bilinear_resize_python -from .trilinear_resize3d_python import trilinear_resize3d_python +from .resize_python import resize2d_python, resize3d_python from .reorg_python import reorg_python from .roi_align_python import roi_align_nchw_python, roi_align_nhwc_python from .roi_pool_python import roi_pool_nchw_python diff --git a/python/tvm/topi/testing/bilinear_resize_python.py b/python/tvm/topi/testing/bilinear_resize_python.py deleted file mode 100644 index b1fb8b0b4845..000000000000 --- a/python/tvm/topi/testing/bilinear_resize_python.py +++ /dev/null @@ -1,105 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals -"""Bilinear Scale in python""" -import math -import numpy as np -from tvm.topi.utils import nchw_pack_layout - - -def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"): - """Bilinear scaling using python""" - (new_h, new_w) = out_size - (ib, ic) = (1, 1) - - if layout == "NHWC": - (batch, h, w, channel) = image.shape - scaled_image = np.ones((batch, new_h, new_w, channel)) - # NCHWinic - elif nchw_pack_layout(layout): - (batch, channel, h, w, ib, ic) = image.shape - scaled_image = np.ones((batch, channel, new_h, new_w, ib, ic)) - else: - (batch, channel, h, w) = image.shape - scaled_image = np.ones((batch, channel, new_h, new_w)) - - if coordinate_transformation_mode == "align_corners": - height_scale = np.float32(h - 1) / np.float32(out_size[0] - 1) - width_scale = np.float32(w - 1) / np.float32(out_size[1] - 1) - else: - height_scale = np.float32(h) / np.float32(out_size[0]) - width_scale = np.float32(w) / np.float32(out_size[1]) - - def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - def _img_scale(b, m, i, n): - for j in range(new_h): - for k in range(new_w): - if coordinate_transformation_mode == "half_pixel": - in_y = (j + 0.5) * height_scale - 0.5 - else: - in_y = j * height_scale - y0 = int(math.floor(in_y)) - y1 = max(min(y0 + 1, h - 1), 0) - y0 = max(y0, 0) - y_lerp = in_y - math.floor(in_y) - - if coordinate_transformation_mode == "half_pixel": - in_x = (k + 0.5) * width_scale - 0.5 - else: - in_x = k * width_scale - x0 = int(math.floor(in_x)) - x1 = max(min(x0 + 1, w - 1), 0) - x0 = max(x0, 0) - x_lerp = in_x - math.floor(in_x) - - if layout == "NHWC": - A = image[b][y0][x0][i] - B = image[b][y0][x1][i] - C = image[b][y1][x0][i] - D = image[b][y1][x1][i] - elif nchw_pack_layout(layout): - A = image[b][i][y0][x0][m][n] - B = image[b][i][y0][x1][m][n] - C = image[b][i][y1][x0][m][n] - D = image[b][i][y1][x1][m][n] - else: - A = image[b][i][y0][x0] - B = image[b][i][y0][x1] - C = image[b][i][y1][x0] - D = image[b][i][y1][x1] - - top = _lerp(A, B, x_lerp) - bottom = _lerp(C, D, x_lerp) - - pixel = np.float32(_lerp(top, bottom, y_lerp)) - - if layout == "NHWC": - scaled_image[b][j][k][i] = pixel - elif nchw_pack_layout(layout): - scaled_image[b][i][j][k][m][n] = pixel - else: - scaled_image[b][i][j][k] = pixel - - for b in range(batch): - for m in range(ib): - for i in range(channel): - for n in range(ic): - _img_scale(b, m, i, n) - - return scaled_image diff --git a/python/tvm/topi/testing/resize_python.py b/python/tvm/topi/testing/resize_python.py new file mode 100644 index 000000000000..23d41127a134 --- /dev/null +++ b/python/tvm/topi/testing/resize_python.py @@ -0,0 +1,192 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""Upsampling in python""" +import math +import numpy as np +from tvm.topi.utils import nchw_pack_layout + + +def get_inx(x, image_width, target_width, coordinate_transformation_mode): + """Infer input x from output x with various coordinate transformation methods""" + scale = image_width / target_width + if coordinate_transformation_mode == "half_pixel": + in_x = (x + 0.5) * scale - 0.5 + elif coordinate_transformation_mode == "align_corners": + in_x = (image_width - 1) / (target_width - 1) * x if target_width > 1 else 0 + elif coordinate_transformation_mode == "asymmetric": + in_x = scale * x + else: + raise ValueError( + "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) + ) + return in_x + + +def get_index(x, image_width, target_width, coordinate_transformation_mode): + in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) + if coordinate_transformation_mode == "align_corners": + # round prefer ceil + out = int(math.floor(in_x + 0.5)) + else: + out = int(math.floor(in_x)) + out = max(min(out, image_width - 1), 0) + return out + + +def resize3d_nearest(arr, scale, coordinate_transformation_mode): + """Populate the array by scale factor""" + d, h, w = arr.shape + out_d, out_h, out_w = [int(round(i * s)) for i, s in zip(arr.shape, scale)] + out = np.empty((out_d, out_h, out_w)) + for z in range(out_d): + for y in range(out_h): + for x in range(out_w): + in_z = get_index(z, d, out_d, coordinate_transformation_mode) + in_y = get_index(y, h, out_h, coordinate_transformation_mode) + in_x = get_index(x, w, out_w, coordinate_transformation_mode) + out[z, y, x] = arr[in_z, in_y, in_x] + return out + + +def resize3d_linear(data_in, scale, coordinate_transformation_mode): + """Trilinear 3d scaling using python""" + d, h, w = data_in.shape + new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)] + data_out = np.ones((new_d, new_h, new_w)) + + def _lerp(A, B, t): + return A * (1.0 - t) + B * t + + def _in_coord(new_coord, in_shape, out_shape): + in_coord = get_inx(new_coord, in_shape, out_shape, coordinate_transformation_mode) + coord0 = int(math.floor(in_coord)) + coord1 = max(min(coord0 + 1, in_shape - 1), 0) + coord0 = max(coord0, 0) + coord_lerp = in_coord - math.floor(in_coord) + return coord0, coord1, coord_lerp + + for m in range(new_d): + for j in range(new_h): + for k in range(new_w): + z0, z1, z_lerp = _in_coord(m, d, new_d) + y0, y1, y_lerp = _in_coord(j, h, new_h) + x0, x1, x_lerp = _in_coord(k, w, new_w) + + A0 = data_in[z0][y0][x0] + B0 = data_in[z0][y0][x1] + C0 = data_in[z0][y1][x0] + D0 = data_in[z0][y1][x1] + A1 = data_in[z1][y0][x0] + B1 = data_in[z1][y0][x1] + C1 = data_in[z1][y1][x0] + D1 = data_in[z1][y1][x1] + + A = _lerp(A0, A1, z_lerp) + B = _lerp(B0, B1, z_lerp) + C = _lerp(C0, C1, z_lerp) + D = _lerp(D0, D1, z_lerp) + top = _lerp(A, B, x_lerp) + bottom = _lerp(C, D, x_lerp) + + data_out[m][j][k] = np.float32(_lerp(top, bottom, y_lerp)) + + return data_out + + +def resize3d_ncdhw( + data, scale, method="nearest_neighbor", coordinate_transformation_mode="align_corners" +): + ishape = data.shape + + oshape = ( + ishape[0], + ishape[1], + int(round(ishape[2] * scale[0])), + int(round(ishape[3] * scale[1])), + int(round(ishape[4] * scale[2])), + ) + + output_np = np.zeros(oshape, dtype=data.dtype) + + for b in range(oshape[0]): + for c in range(oshape[1]): + if method == "nearest_neighbor": + output_np[b, c, :, :, :] = resize3d_nearest( + data[b, c, :, :, :], scale, coordinate_transformation_mode + ) + elif method == "linear": + output_np[b, c, :, :, :] = resize3d_linear( + data[b, c, :, :, :], scale, coordinate_transformation_mode + ) + else: + raise ValueError("Unknown resize method", method) + + return output_np + + +def resize2d_python( + data, + scale, + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="align_corners", +): + """Python version of scaling using nearest neighbour""" + + if layout == "NHWC": + data = data.transpose([0, 3, 1, 2]) + elif nchw_pack_layout(layout): + ishape = data.shape + transposed = data.transpose([0, 4, 1, 5, 2, 3]) + tshape = transposed.shape + data = transposed.reshape( + tshape[0] * tshape[1], tshape[2] * tshape[3], tshape[4], tshape[5] + ) + + data = np.expand_dims(data, axis=2) + + output_np = resize3d_ncdhw(data, (1,) + scale, method, coordinate_transformation_mode) + output_np = np.squeeze(output_np, axis=2) + + if layout == "NHWC": + output_np = output_np.transpose([0, 2, 3, 1]) + elif nchw_pack_layout(layout): + output_np = output_np.reshape(tshape[0:4] + output_np.shape[2:]) + output_np = output_np.transpose([0, 2, 4, 5, 1, 3]) + + return output_np + + +def resize3d_python( + data, + scale, + layout="NCDHW", + method="nearest_neighbor", + coordinate_transformation_mode="align_corners", +): + """Python version of 3D scaling using nearest neighbour""" + + if layout == "NDHWC": + data = data.transpose([0, 4, 1, 2, 3]) + + output_np = resize3d_ncdhw(data, scale, method, coordinate_transformation_mode) + + if layout == "NDHWC": + output_np = output_np.transpose([0, 2, 3, 4, 1]) + + return output_np diff --git a/python/tvm/topi/testing/trilinear_resize3d_python.py b/python/tvm/topi/testing/trilinear_resize3d_python.py deleted file mode 100644 index d603e987d5ef..000000000000 --- a/python/tvm/topi/testing/trilinear_resize3d_python.py +++ /dev/null @@ -1,111 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-nested-blocks -"""Trilinear 3D resize in python""" -import math -import numpy as np - - -def trilinear_resize3d_python( - data_in, out_size, layout, coordinate_transformation_mode="align_corners" -): - """Trilinear 3d scaling using python""" - (new_d, new_h, new_w) = out_size - - if layout == "NDHWC": - (batch, d, h, w, channel) = data_in.shape - data_out = np.ones((batch, new_d, new_h, new_w, channel)) - else: - (batch, channel, d, h, w) = data_in.shape - data_out = np.ones((batch, channel, new_d, new_h, new_w)) - - if coordinate_transformation_mode == "align_corners": - depth_scale = np.float32(d - 1) / np.float32(out_size[0] - 1) - height_scale = np.float32(h - 1) / np.float32(out_size[1] - 1) - width_scale = np.float32(w - 1) / np.float32(out_size[2] - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - depth_scale = np.float32(d) / np.float32(out_size[0]) - height_scale = np.float32(h) / np.float32(out_size[1]) - width_scale = np.float32(w) / np.float32(out_size[2]) - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) - ) - - def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - def _in_coord(new_coord, scale, shape, mode): - if mode == "half_pixel": - in_coord = (new_coord + 0.5) * scale - 0.5 - else: - in_coord = new_coord * scale - coord0 = int(math.floor(in_coord)) - coord1 = max(min(coord0 + 1, shape - 1), 0) - coord0 = max(coord0, 0) - coord_lerp = in_coord - math.floor(in_coord) - return coord0, coord1, coord_lerp - - for b in range(batch): - for i in range(channel): - for m in range(new_d): - for j in range(new_h): - for k in range(new_w): - z0, z1, z_lerp = _in_coord( - m, depth_scale, d, coordinate_transformation_mode - ) - y0, y1, y_lerp = _in_coord( - j, height_scale, h, coordinate_transformation_mode - ) - x0, x1, x_lerp = _in_coord( - k, width_scale, w, coordinate_transformation_mode - ) - - if layout == "NDHWC": - A0 = data_in[b][z0][y0][x0][i] - B0 = data_in[b][z0][y0][x1][i] - C0 = data_in[b][z0][y1][x0][i] - D0 = data_in[b][z0][y1][x1][i] - A1 = data_in[b][z1][y0][x0][i] - B1 = data_in[b][z1][y0][x1][i] - C1 = data_in[b][z1][y1][x0][i] - D1 = data_in[b][z1][y1][x1][i] - else: - A0 = data_in[b][i][z0][y0][x0] - B0 = data_in[b][i][z0][y0][x1] - C0 = data_in[b][i][z0][y1][x0] - D0 = data_in[b][i][z0][y1][x1] - A1 = data_in[b][i][z1][y0][x0] - B1 = data_in[b][i][z1][y0][x1] - C1 = data_in[b][i][z1][y1][x0] - D1 = data_in[b][i][z1][y1][x1] - - A = _lerp(A0, A1, z_lerp) - B = _lerp(B0, B1, z_lerp) - C = _lerp(C0, C1, z_lerp) - D = _lerp(D0, D1, z_lerp) - top = _lerp(A, B, x_lerp) - bottom = _lerp(C, D, x_lerp) - - pixel = np.float32(_lerp(top, bottom, y_lerp)) - - if layout == "NDHWC": - data_out[b][m][j][k][i] = pixel - else: - data_out[b][i][m][j][k] = pixel - - return data_out diff --git a/python/tvm/topi/testing/upsampling_python.py b/python/tvm/topi/testing/upsampling_python.py deleted file mode 100644 index dd187c4d8cff..000000000000 --- a/python/tvm/topi/testing/upsampling_python.py +++ /dev/null @@ -1,136 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals -"""Upsampling in python""" -import math -import numpy as np -from tvm.topi.utils import nchw_pack_layout - - -def upsample_nearest(arr, scale): - """Populate the array by scale factor""" - h, w = arr.shape - out_h = int(round(h * scale[0])) - out_w = int(round(w * scale[1])) - out = np.empty((out_h, out_w)) - for y in range(out_h): - for x in range(out_w): - in_y = math.floor(y / scale[0]) - in_x = math.floor(x / scale[1]) - out[y, x] = arr[in_y, in_x] - return out - - -def upsampling_python(data, scale, layout="NCHW"): - """Python version of scaling using nearest neighbour""" - - ishape = data.shape - if layout == "NCHW": - oshape = ( - ishape[0], - ishape[1], - int(round(ishape[2] * scale[0])), - int(round(ishape[3] * scale[1])), - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[1]): - output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale) - return output_np - # NCHWinic - if nchw_pack_layout(layout): - oshape = ( - ishape[0], - ishape[1], - int(round(ishape[2] * scale[0])), - int(round(ishape[3] * scale[1])), - ishape[4], - ishape[5], - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for ib in range(oshape[4]): - for c in range(oshape[1]): - for ic in range(oshape[5]): - output_np[b, c, :, :, ib, ic] = upsample_nearest( - data[b, c, :, :, ib, ic], scale - ) - return output_np - - if layout == "NHWC": - oshape = ( - ishape[0], - int(round(ishape[1] * scale[0])), - int(round(ishape[2] * scale[1])), - ishape[3], - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[3]): - output_np[b, :, :, c] = upsample_nearest(data[b, :, :, c], scale) - return output_np - raise ValueError("not support this layout {} yet".format(layout)) - - -def upsample3d_nearest(arr, scale): - """Populate the array by scale factor""" - d, h, w = arr.shape - out_d = int(round(d * scale[0])) - out_h = int(round(h * scale[1])) - out_w = int(round(w * scale[2])) - out = np.empty((out_d, out_h, out_w)) - for z in range(out_d): - for y in range(out_h): - for x in range(out_w): - in_z = math.floor(z / scale[0]) - in_y = math.floor(y / scale[1]) - in_x = math.floor(x / scale[2]) - out[z, y, x] = arr[in_z, in_y, in_x] - return out - - -def upsampling3d_python(data, scale, layout="NCDHW"): - """Python version of 3D scaling using nearest neighbour""" - - ishape = data.shape - if layout == "NCDHW": - oshape = ( - ishape[0], - ishape[1], - int(round(ishape[2] * scale[0])), - int(round(ishape[3] * scale[1])), - int(round(ishape[4] * scale[2])), - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[1]): - output_np[b, c, :, :, :] = upsample3d_nearest(data[b, c, :, :, :], scale) - return output_np - if layout == "NDHWC": - oshape = ( - ishape[0], - int(round(ishape[1] * scale[0])), - int(round(ishape[2] * scale[1])), - int(round(ishape[3] * scale[2])), - ishape[4], - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[4]): - output_np[b, :, :, :, c] = upsample3d_nearest(data[b, :, :, :, c], scale) - return output_np - raise ValueError("not support this layout {} yet".format(layout)) diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index 72dac9b2501f..ee9159573ea2 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -206,12 +206,15 @@ def verify_UpsampleLayerParams(input_dim, scale, mode): dtype = "float32" a_np = np.full(input_dim, 1, dtype=dtype) + if mode == "NN": - b_np = tvm.topi.testing.upsampling_python(a_np, (scale, scale)) + method = "nearest_neighbor" + coord_trans = "asymmetric" else: - new_h = input_dim[2] * scale - new_w = input_dim[3] * scale - b_np = tvm.topi.testing.bilinear_resize_python(a_np, (new_h, new_w), "NCHW") + method = "linear" + coord_trans = "align_corners" + + b_np = tvm.topi.testing.resize2d_python(a_np, (scale, scale), "NCHW", method, coord_trans) input = [("input", datatypes.Array(*input_dim))] output = [("output", datatypes.Array(*b_np.shape))] diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index e2c865b75b3d..3c1098c2c1cd 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1366,10 +1366,11 @@ def verify_upsample3d_trilinear(): y = helper.make_node("Upsample", ["in", "scales"], ["out"], mode="linear") scales = [1.0, 1.0, 2.0, 2.0, 2.0] in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = tvm.topi.testing.trilinear_resize3d_python( + out_array = tvm.topi.testing.resize3d_python( in_array, - (3 * scale, 3 * scale, 3 * scale), + (scale, scale, scale), "NCDHW", + "linear", coordinate_transformation_mode="asymmetric", ) diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index 8e0a11a6c9d2..a6ea609be1e2 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -40,12 +40,13 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa (n, h, w, c) = dshape x_data = np.random.uniform(size=(n, h, w, c)).astype("float32") - if method == "nearest_neighbor": - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h, scale_w), layout) - else: - ref_res = tvm.topi.testing.bilinear_resize_python( - x_data, (int(round(h * scale_h)), int(round(w * scale_w))), layout - ) + ref_res = tvm.topi.testing.resize2d_python( + x_data, + (scale_h, scale_w), + layout, + method[2:] if method[0:2] == "bi" else method, + "align_corners" if align_corners else "asymmetric", + ) x = relay.Var("x", relay.TensorType(dshape, "float32")) scale_h_var = relay.var("scale_h", relay.TensorType((), "float32")) scale_w_var = relay.var("scale_h", relay.TensorType((), "float32")) @@ -98,19 +99,14 @@ def verify_upsampling3d( (n, d, h, w, c) = dshape x_data = np.random.uniform(size=(n, d, h, w, c)).astype("float32") - if method == "nearest_neighbor": - assert ( - coord_trans == "asymmetric" - ), "topi reference only support asymmetric nearest neighbor" - ref_res = tvm.topi.testing.upsampling3d_python( - x_data, (scale_d, scale_h, scale_w), layout - ) - else: - ref_res = tvm.topi.testing.trilinear_resize3d_python( - x_data, - (int(round(d * scale_d)), int(round(h * scale_h)), int(round(w * scale_w))), - layout, - ) + ref_res = tvm.topi.testing.resize3d_python( + x_data, + (scale_d, scale_h, scale_w), + layout, + method[3:] if method[0:3] == "tri" else method, + coord_trans, + ) + x = relay.Var("x", relay.TensorType(dshape, "float32")) scale_d_var = relay.var("scale_d", relay.TensorType((), "float32")) scale_h_var = relay.var("scale_h", relay.TensorType((), "float32")) diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py index 305fe4d2380f..d3459afaab06 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level5.py +++ b/tests/python/relay/dyn/test_dynamic_op_level5.py @@ -45,10 +45,7 @@ def verify_resize2d(dshape, scale, method, layout): size = (dshape[2] * scale, dshape[3] * scale) size = np.array(size).astype("int64") x_data = np.random.uniform(size=dshape).astype("float32") - if method == "linear": - ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout) - else: - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) + x = relay.var("x", relay.TensorType(dshape, "float32")) size_var = relay.var("size", relay.TensorType((2,), "int64")) @@ -60,6 +57,10 @@ def verify_resize2d(dshape, scale, method, layout): zz = run_infer_type(z) func = relay.Function([x, size_var], z) + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale, scale), layout, method, coord_trans + ) + for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 907ce79d82f4..f05c5054415d 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1450,12 +1450,13 @@ def get_shape(): func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) - if method == "nearest_neighbor": - ref = tvm.topi.testing.upsampling_python(data, (scale_h, scale_w), layout) - else: - ref = tvm.topi.testing.bilinear_resize_python( - data, (int(round(h * scale_h)), int(round(w * scale_w))), layout - ) + ref = tvm.topi.testing.resize2d_python( + data, + (scale_h, scale_w), + layout, + method[2:] if method[0:2] == "bi" else method, + "align_corners" if align_corners else "asymmetric", + ) for target, dev in tvm.testing.enabled_targets(): executor = relay.create_executor("graph", device=dev, target=target) out = executor.evaluate(func)(data) @@ -1521,17 +1522,13 @@ def get_shape(): func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) - if method == "nearest_neighbor": - assert ( - coordinate_transformation_mode == "asymmetric" - ), "topi reference only support asymmetric nearest neighbor" - ref = tvm.topi.testing.upsampling3d_python(data, (scale_d, scale_h, scale_w), layout) - else: - ref = tvm.topi.testing.trilinear_resize3d_python( - data, - (int(round(d * scale_d)), int(round(h * scale_h)), int(round(w * scale_w))), - layout, - ) + ref = tvm.topi.testing.resize3d_python( + data, + (scale_d, scale_h, scale_w), + layout, + method[3:] if method[0:3] == "tri" else method, + coordinate_transformation_mode, + ) for target, dev in tvm.testing.enabled_targets(): executor = relay.create_executor("graph", device=dev, target=target) out = executor.evaluate(func)(data) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 9b12b9d12535..9d74d85981ae 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -51,10 +51,9 @@ def verify_resize(dshape, scale, method, layout, coord_trans): x_data = np.random.uniform(size=dshape).astype("float32") - if method == "linear": - ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout, coord_trans) - else: - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale, scale), layout, method, coord_trans + ) x = relay.var("x", relay.TensorType(dshape, "float32")) z = relay.image.resize2d( x, size, layout, method, coordinate_transformation_mode=coord_trans @@ -63,7 +62,6 @@ def verify_resize(dshape, scale, method, layout, coord_trans): zz = run_infer_type(z) assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") func = relay.Function([x], z) - for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, device=dev, target=target) @@ -71,7 +69,7 @@ def verify_resize(dshape, scale, method, layout, coord_trans): tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) for method in ["nearest_neighbor", "linear"]: - for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: for layout in ["NHWC", "NCHW"]: verify_resize((1, 4, 4, 4), 2, method, layout, coord_trans) verify_resize((2, 8, 17, 20), 3, method, layout, coord_trans) @@ -109,10 +107,9 @@ def verify_resize(dshape, scale, method, layout): size = (dshape[2] * scale, dshape[3] * scale, dshape[4] * scale) x_data = np.random.uniform(size=dshape).astype("float32") - if method == "linear": - ref_res = tvm.topi.testing.trilinear_resize3d_python(x_data, size, layout) - else: - ref_res = tvm.topi.testing.upsampling3d_python(x_data, (scale, scale, scale), layout) + ref_res = tvm.topi.testing.resize3d_python( + x_data, (scale, scale, scale), layout, method, "align_corners" + ) x = relay.var("x", relay.TensorType(dshape, "float32")) z = relay.image.resize3d(x, size, layout, method, "align_corners") assert "size=" in z.astext() @@ -125,9 +122,10 @@ def verify_resize(dshape, scale, method, layout): op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) - for method in ["linear", "nearest_neighbor"]: - for layout in ["NDHWC", "NCDHW"]: - verify_resize((1, 4, 4, 4, 4), 2, method, layout) + for method in ["nearest_neighbor", "linear"]: + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + for layout in ["NDHWC", "NCDHW"]: + verify_resize((1, 4, 4, 4, 4), 2, method, layout) @tvm.testing.uses_gpu diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 09f4b43449d7..ef9d8b2c4596 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -270,12 +270,9 @@ def verify_resize(shape, scale, method, layout): assert zz.op == relay.op.get("image.resize2d") x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") - - if method == "linear": - ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout) - else: - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) - verify_func(func2, [x_data], ref_res, rtol=1e-4, atol=1e-6) + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale, scale), layout, method, coord_trans + ) for method in ["linear", "nearest_neighbor"]: for layout in ["NCHW", "NHWC"]: @@ -347,7 +344,9 @@ def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype): assert zz.op == relay.op.get("nn.upsampling") x_data = np.random.uniform(size=data_shape).astype(dtype) - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h_val, scale_w_val), "NCHW") + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale, scale), "NCHW", "nearest_neighbor", "asymmetric" + ) verify_func(func2, [x_data], ref_res) verify_upsampling((1, 16, 32, 32), 2, 2, "int8") @@ -371,8 +370,12 @@ def verify_upsampling3d(data_shape, scale_d_val, scale_h_val, scale_w_val, dtype assert zz.op == relay.op.get("nn.upsampling3d") x_data = np.random.uniform(size=data_shape).astype(dtype) - ref_res = tvm.topi.testing.upsampling3d_python( - x_data, (scale_d_val, scale_h_val, scale_w_val), "NCDHW" + ref_res = tvm.topi.testing.resize3d_python( + x_data, + (scale_d_val, scale_h_val, scale_w_val), + "NCDHW", + "nearest_neighbor", + "asymmetric", ) verify_func(func2, [x_data], ref_res) diff --git a/tests/python/topi/python/test_topi_image.py b/tests/python/topi/python/test_topi_image.py index 381cdc08d890..fe7fba52f1ee 100644 --- a/tests/python/topi/python/test_topi_image.py +++ b/tests/python/topi/python/test_topi_image.py @@ -54,17 +54,9 @@ def verify_resize2d( coordinate_transformation_mode=coord_trans, method=method, ) - if method == "linear": - b_np = tvm.topi.testing.bilinear_resize_python( - a_np, (out_height, out_width), layout, coord_trans - ) - else: - # TODO: Nearest neighbor case doesn't do anything with coordinate transform mode, and also - # nearest_neighbors and align_corners combination in topi doesn't match the output of this - # function. - scale_h = out_height / in_height - scale_w = out_width / in_width - b_np = tvm.topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) + scale_h = out_height / in_height + scale_w = out_width / in_width + b_np = tvm.topi.testing.resize2d_python(a_np, (scale_h, scale_w), layout, method, coord_trans) def check_target(target, dev): print("Running on target: %s" % target) @@ -93,8 +85,10 @@ def test_resize2d(): verify_resize2d(6, 32, 64, 64, 20, 20, "NHWC") for layout in ["NCHW", "NHWC"]: verify_resize2d(4, 16, 32, 32, 50, 50, layout, "asymmetric", method="nearest_neighbor") - verify_resize2d(4, 16, 32, 32, 50, 50, layout, "half_pixel", method="linear") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "align_corners", method="nearest_neighbor") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "half_pixel", method="nearest_neighbor") verify_resize2d(4, 16, 32, 32, 50, 50, layout, "asymmetric", method="linear") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "half_pixel", method="linear") def verify_resize3d( @@ -139,15 +133,12 @@ def verify_resize3d( method=method, ) - if method == "linear": - b_np = tvm.topi.testing.trilinear_resize3d_python( - a_np, (out_depth, out_height, out_width), layout, coordinate_transformation_mode - ) - else: - scale_d = out_depth / in_depth - scale_h = out_height / in_height - scale_w = out_width / in_width - b_np = tvm.topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) + scale_d = out_depth / in_depth + scale_h = out_height / in_height + scale_w = out_width / in_width + b_np = tvm.topi.testing.resize3d_python( + a_np, (scale_d, scale_h, scale_w), layout, method, coordinate_transformation_mode + ) def check_target(target, dev): with tvm.target.Target(target): @@ -166,17 +157,10 @@ def check_target(target, dev): @tvm.testing.uses_gpu def test_resize3d(): # Trilinear - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "half_pixel") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "half_pixel") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "align_corners") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "align_corners") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "asymmetric") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "asymmetric") - - # Nearest neighbor - # Test kernel only supports asymmetric - verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NCDHW", method="nearest_neighbor") - verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NDHWC", method="nearest_neighbor") + for method in ["nearest_neighbor", "linear"]: + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + for layout in ["NCDHW", "NDHWC"]: + verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, layout, coord_trans, method) @tvm.testing.uses_gpu diff --git a/tests/python/topi/python/test_topi_upsampling.py b/tests/python/topi/python/test_topi_upsampling.py index 8dfe7d7a24b8..7793417a9a2b 100644 --- a/tests/python/topi/python/test_topi_upsampling.py +++ b/tests/python/topi/python/test_topi_upsampling.py @@ -78,11 +78,13 @@ def verify_upsampling( B = topi.nn.upsampling(A, scale_h, scale_w, layout=layout, method=method, align_corners=False) - if method == "bilinear": - out_size = (int(round(in_height * scale_h)), int(round(in_width * scale_w))) - b_np = tvm.topi.testing.bilinear_resize_python(a_np, out_size, layout, "asymmetric") - else: - b_np = tvm.topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) + b_np = tvm.topi.testing.resize2d_python( + a_np, + (scale_h, scale_w), + layout, + method[2:] if method[0:2] == "bi" else method, + "asymmetric", + ) def check_target(target, dev): print("Running on target: %s" % target) @@ -216,17 +218,13 @@ def verify_upsampling3d( coordinate_transformation_mode="asymmetric", ) - if method == "trilinear": - out_size = ( - int(round(in_depth * scale_d)), - int(round(in_height * scale_h)), - int(round(in_width * scale_w)), - ) - b_np = tvm.topi.testing.trilinear_resize3d_python( - a_np, out_size, layout, coordinate_transformation_mode="asymmetric" - ) - else: - b_np = tvm.topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) + b_np = tvm.topi.testing.resize3d_python( + a_np, + (scale_d, scale_h, scale_w), + layout, + method[3:] if method[0:3] == "tri" else method, + "asymmetric", + ) def check_target(target, dev): print("Running on target: %s" % target) From 49dc7a9d0469250b3b290875d12b569b823ea760 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 2 Jul 2021 15:07:18 -0600 Subject: [PATCH 10/12] add cubic resize reference kernel and tests, add relay tests for resize1d --- python/tvm/topi/image/resize.py | 8 +- python/tvm/topi/testing/__init__.py | 2 +- python/tvm/topi/testing/resize_python.py | 97 +++++++++++++++++++++++- tests/python/relay/test_op_level5.py | 57 +++++++++++++- 4 files changed, 154 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index e8d071ae73b2..5d9d96036282 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -1085,10 +1085,10 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): wy = [w / sum_wy for w in wy] wx = [w / sum_wx for w in wx] - l = [[0 for i in range(2)] for j in range(2)] - for j in range(2): - for i in range(2): - l[j][i] = _cubic_kerel(p[j][i], wx) + l = [[0 for i in range(4)] for j in range(4)] + for j in range(4): + for i in range(4): + l[j][i] = _cubic_kernel(p[j][i], wx) col0 = _cubic_kernel(l[0], wy) col1 = _cubic_kernel(l[1], wy) col2 = _cubic_kernel(l[2], wy) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 871059bf5ab4..e23ecfa8fc69 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -35,7 +35,7 @@ from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .dilate_python import dilate_python from .softmax_python import softmax_python, log_softmax_python -from .resize_python import resize2d_python, resize3d_python +from .resize_python import resize1d_python, resize2d_python, resize3d_python from .reorg_python import reorg_python from .roi_align_python import roi_align_nchw_python, roi_align_nhwc_python from .roi_pool_python import roi_pool_nchw_python diff --git a/python/tvm/topi/testing/resize_python.py b/python/tvm/topi/testing/resize_python.py index 23d41127a134..f269b8e8b868 100644 --- a/python/tvm/topi/testing/resize_python.py +++ b/python/tvm/topi/testing/resize_python.py @@ -108,6 +108,75 @@ def _in_coord(new_coord, in_shape, out_shape): return data_out +def resize3d_cubic(data_in, scale, coordinate_transformation_mode): + """Tricubic 3d scaling using python""" + d, h, w = data_in.shape + new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)] + data_out = np.ones((new_d, new_h, new_w)) + + def _cubic_spline_weights(t, alpha=-0.5): + """create cubic spline weights in 1D""" + t2 = t * t + t3 = t * t * t + w1 = alpha * (t3 - 2 * t2 + t) + w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 + w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t + w4 = -alpha * t3 + alpha * t2 + return [w1, w2, w3, w4] + + def _cubic_kernel(inputs, w): + """perform cubic interpolation in 1D""" + return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) + + def _get_input_value(z, y, x): + z = max(min(z, d - 1), 0) + y = max(min(y, h - 1), 0) + x = max(min(x, w - 1), 0) + return data_in[z][y][x] + + for m in range(new_d): + for j in range(new_h): + for k in range(new_w): + in_z = get_inx(m, d, new_d, coordinate_transformation_mode) + in_y = get_inx(j, h, new_h, coordinate_transformation_mode) + in_x = get_inx(k, w, new_w, coordinate_transformation_mode) + zint = math.floor(in_z) + zfract = in_z - math.floor(in_z) + + yint = math.floor(in_y) + yfract = in_y - math.floor(in_y) + + xint = math.floor(in_x) + xfract = in_x - math.floor(in_x) + + # Get the surrounding values + p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)] + for kk in range(4): + for jj in range(4): + for ii in range(4): + p[kk][jj][ii] = _get_input_value( + zint + kk - 1, + yint + jj - 1, + xint + ii - 1, + ) + + wz = _cubic_spline_weights(zfract) + wy = _cubic_spline_weights(yfract) + wx = _cubic_spline_weights(xfract) + + l = [[0 for i in range(4)] for j in range(4)] + for jj in range(4): + for ii in range(4): + l[jj][ii] = _cubic_kernel(p[jj][ii], wx) + col0 = _cubic_kernel(l[0], wy) + col1 = _cubic_kernel(l[1], wy) + col2 = _cubic_kernel(l[2], wy) + col3 = _cubic_kernel(l[3], wy) + data_out[m][j][k] = _cubic_kernel([col0, col1, col2, col3], wz) + + return data_out + + def resize3d_ncdhw( data, scale, method="nearest_neighbor", coordinate_transformation_mode="align_corners" ): @@ -133,12 +202,39 @@ def resize3d_ncdhw( output_np[b, c, :, :, :] = resize3d_linear( data[b, c, :, :, :], scale, coordinate_transformation_mode ) + elif method == "cubic": + output_np[b, c, :, :, :] = resize3d_cubic( + data[b, c, :, :, :], scale, coordinate_transformation_mode + ) else: raise ValueError("Unknown resize method", method) return output_np +def resize1d_python( + data, + scale, + layout="NCW", + method="nearest_neighbor", + coordinate_transformation_mode="align_corners", +): + """Python version of 3D scaling using nearest neighbour""" + + if layout == "NWC": + data = data.transpose([0, 2, 1]) + + data = np.expand_dims(data, axis=[2, 3]) + output_np = resize3d_ncdhw(data, (1, 1) + scale, method, coordinate_transformation_mode) + output_np = np.squeeze(output_np, axis=2) + output_np = np.squeeze(output_np, axis=2) + + if layout == "NWC": + output_np = output_np.transpose([0, 2, 1]) + + return output_np + + def resize2d_python( data, scale, @@ -159,7 +255,6 @@ def resize2d_python( ) data = np.expand_dims(data, axis=2) - output_np = resize3d_ncdhw(data, (1,) + scale, method, coordinate_transformation_mode) output_np = np.squeeze(output_np, axis=2) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 9d74d85981ae..d93de5419f56 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -26,7 +26,56 @@ from tvm.relay.testing import run_infer_type -def test_resize_infer_type(): +def test_resize1d_infer_type(): + n, c, w = te.size_var("n"), te.size_var("c"), te.size_var("w") + x = relay.var("x", relay.TensorType((n, c, w), "int8")) + tw = te.var("tw") + z = relay.image.resize1d(x, (tw,)) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, tw), "int8") + + x = relay.var("x", relay.TensorType((n, c, w), "int8")) + z = relay.image.resize1d(x, (200,), "NCW", "linear", "align_corners") + assert "size=" in z.astext() + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, 200), "int8") + + +@tvm.testing.uses_gpu +def test_resize1d(): + def verify_resize(dshape, scale, method, layout, coord_trans): + if layout == "NWC": + size = (dshape[1] * scale,) + else: + size = (dshape[2] * scale,) + + x_data = np.random.uniform(size=dshape).astype("float32") + + ref_res = tvm.topi.testing.resize1d_python(x_data, (scale,), layout, method, coord_trans) + x = relay.var("x", relay.TensorType(dshape, "float32")) + z = relay.image.resize1d( + x, size, layout, method, coordinate_transformation_mode=coord_trans + ) + assert "size=" in z.astext() + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") + func = relay.Function([x], z) + for target, dev in tvm.testing.enabled_targets(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, device=dev, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) + + for method in ["nearest_neighbor", "linear", "cubic"]: + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + for layout in ["NWC", "NCW"]: + verify_resize((1, 4, 4), 2, method, layout, coord_trans) + verify_resize((2, 8, 17), 3, method, layout, coord_trans) + verify_resize((2, 8, 17), 3, method, layout, coord_trans) + verify_resize((3, 4, 5), 5, method, layout, coord_trans) + + +def test_resize2d_infer_type(): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) th, tw = te.var("th"), te.var("tw") @@ -42,7 +91,7 @@ def test_resize_infer_type(): @tvm.testing.uses_gpu -def test_resize(): +def test_resize2d(): def verify_resize(dshape, scale, method, layout, coord_trans): if layout == "NHWC": size = (dshape[1] * scale, dshape[2] * scale) @@ -68,7 +117,7 @@ def verify_resize(dshape, scale, method, layout, coord_trans): op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) - for method in ["nearest_neighbor", "linear"]: + for method in ["nearest_neighbor", "linear", "cubic"]: for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: for layout in ["NHWC", "NCHW"]: verify_resize((1, 4, 4, 4), 2, method, layout, coord_trans) @@ -122,7 +171,7 @@ def verify_resize(dshape, scale, method, layout): op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) - for method in ["nearest_neighbor", "linear"]: + for method in ["nearest_neighbor", "linear", "cubic"]: for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: for layout in ["NDHWC", "NCDHW"]: verify_resize((1, 4, 4, 4, 4), 2, method, layout) From 05adba5e8907d02bde7573f43e4800f9fb980586 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 2 Jul 2021 15:15:53 -0600 Subject: [PATCH 11/12] fix pylint --- python/tvm/topi/testing/resize_python.py | 29 +++++++++++++++--------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/python/tvm/topi/testing/resize_python.py b/python/tvm/topi/testing/resize_python.py index f269b8e8b868..e8d5c0599887 100644 --- a/python/tvm/topi/testing/resize_python.py +++ b/python/tvm/topi/testing/resize_python.py @@ -38,6 +38,7 @@ def get_inx(x, image_width, target_width, coordinate_transformation_mode): def get_index(x, image_width, target_width, coordinate_transformation_mode): + """get and round the nearest index for nearest_neighbor""" in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) if coordinate_transformation_mode == "align_corners": # round prefer ceil @@ -134,6 +135,19 @@ def _get_input_value(z, y, x): x = max(min(x, w - 1), 0) return data_in[z][y][x] + def _get_patch(zint, yint, xint): + # Get the surrounding values + p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)] + for kk in range(4): + for jj in range(4): + for ii in range(4): + p[kk][jj][ii] = _get_input_value( + zint + kk - 1, + yint + jj - 1, + xint + ii - 1, + ) + return p + for m in range(new_d): for j in range(new_h): for k in range(new_w): @@ -149,25 +163,17 @@ def _get_input_value(z, y, x): xint = math.floor(in_x) xfract = in_x - math.floor(in_x) - # Get the surrounding values - p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)] - for kk in range(4): - for jj in range(4): - for ii in range(4): - p[kk][jj][ii] = _get_input_value( - zint + kk - 1, - yint + jj - 1, - xint + ii - 1, - ) - wz = _cubic_spline_weights(zfract) wy = _cubic_spline_weights(yfract) wx = _cubic_spline_weights(xfract) + p = _get_patch(zint, yint, xint) + l = [[0 for i in range(4)] for j in range(4)] for jj in range(4): for ii in range(4): l[jj][ii] = _cubic_kernel(p[jj][ii], wx) + col0 = _cubic_kernel(l[0], wy) col1 = _cubic_kernel(l[1], wy) col2 = _cubic_kernel(l[2], wy) @@ -180,6 +186,7 @@ def _get_input_value(z, y, x): def resize3d_ncdhw( data, scale, method="nearest_neighbor", coordinate_transformation_mode="align_corners" ): + """reference kernel for 3D image resizing""" ishape = data.shape oshape = ( From c2726b8a16116e1954e67f9565fbd5c509057670 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 2 Jul 2021 17:24:29 -0600 Subject: [PATCH 12/12] fix test typo --- tests/python/relay/test_pass_dynamic_to_static.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index ef9d8b2c4596..962b7bebb12b 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -345,7 +345,7 @@ def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype): x_data = np.random.uniform(size=data_shape).astype(dtype) ref_res = tvm.topi.testing.resize2d_python( - x_data, (scale, scale), "NCHW", "nearest_neighbor", "asymmetric" + x_data, (scale_h_val, scale_w_val), "NCHW", "nearest_neighbor", "asymmetric" ) verify_func(func2, [x_data], ref_res)