diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index d65f5676fb33..0b6ebdb5d5c2 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -622,6 +622,69 @@ class MaxPool(Pool):
     name = "max_pool"
 
 
+class MaxUnpool(OnnxOpConverter):
+    """Operator converter for MaxUnpool"""
+
+    @classmethod
+    def _impl_v11(cls, inputs, attr, params):
+        """Build the output by scattering `data` into zeros at the flat `indices`.
+
+        inputs[0] is the pooled data, inputs[1] the indices and inputs[2] the
+        optional output_shape, which overrides the computed shape when given.
+        Otherwise each spatial dim is in * stride + kernel - stride, minus the
+        total padding of that axis; the two leading dims are left unchanged.
+        """
+        # Unpack inputs and attributes
+        data = inputs[0]
+        data_type = infer_type(data).checked_type.dtype
+        indices = inputs[1]
+        output_shape = inputs[2]
+        kernel_shape = attr.get("kernel_shape")
+        pads = attr.get("pads", None)
+        strides = attr.get("strides", [1] * len(kernel_shape))
+
+        # Compute the proper output shape before padding.
+        multiplier = _op.concatenate(
+            [_expr.const([1, 1], dtype="int64"), _expr.const(list(strides), dtype="int64")], axis=0
+        )
+        total_output_shape = multiplier * _op.shape_of(data, dtype="int64")
+        # Add extra dimensions from kernel size and stride mismatch
+        total_output_shape += _op.concatenate(
+            [_expr.const([0, 0], "int64"), _expr.const(list(kernel_shape), "int64")], axis=0
+        ) - _op.concatenate(
+            [_expr.const([0, 0], "int64"), _expr.const(list(strides), "int64")], axis=0
+        )
+
+        # An explicit output shape takes precedence over the pads attribute.
+        if output_shape is not None:
+            total_output_shape = output_shape
+
+        elif pads is not None:
+            # ONNX orders pads as [x1_begin, x2_begin, ..., x1_end, x2_end, ...].
+            pads = list(pads)
+            num_spatial = len(pads) // 2
+            # Total padding per axis is begin + end; the two leading dims get none.
+            total_pad = [0, 0] + [
+                begin + end for begin, end in zip(pads[:num_spatial], pads[num_spatial:])
+            ]
+            # Reversing maxpool means that padding actually makes our output smaller.
+            total_output_shape = total_output_shape - _expr.const(total_pad, "int64")
+
+        # Create a tensor of zeros then scatter our data through it.
+        zeros_tensor = _op.zeros(total_output_shape, data_type)
+        # We need to flatten all our tensors before scattering.
+        flat_tensor = _op.scatter(
+            _op.reshape(zeros_tensor, [-1]),
+            _op.reshape(indices, [-1]),
+            _op.reshape(data, [-1]),
+            axis=0,
+        )
+        # Now reshape back to prepadded shape.
+        output_tensor = _op.reshape(flat_tensor, total_output_shape)
+
+        return output_tensor
+
+
 class LpPool(OnnxOpConverter):
     """A helper class for lppool op converters."""
 
@@ -2330,6 +2393,7 @@ def _get_convert_map(opset):
         "AveragePool": AveragePool.get_converter(opset),
         "LpPool": LpPool.get_converter(opset),
         "MaxPool": MaxPool.get_converter(opset),
+        "MaxUnpool": MaxUnpool.get_converter(opset),
         "Conv": Conv.get_converter(opset),
         "ConvTranspose": ConvTranspose.get_converter(opset),
         "GlobalAveragePool": Renamer("global_avg_pool2d"),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 3ddc80af3a32..1e0b729cbef0 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3915,6 +3915,76 @@ def verify_size(indata):
     verify_size(input_data)
 
 
+@tvm.testing.uses_gpu
+def test_maxunpool():
+    def verify_maxunpool(data, indices, kernel_shape, strides, output_shape=None, pads=None):
+        """Build a one-node MaxUnpool model and check it via verify_with_ort_with_inputs."""
+        input_names = ["xT", "xI"]
+        input_info = [
+            helper.make_tensor_value_info("xT", TensorProto.FLOAT, list(data.shape)),
+            helper.make_tensor_value_info("xI", TensorProto.INT64, list(indices.shape)),
+        ]
+        input_values = [data, indices]
+        if output_shape is not None:
+            input_names.append("output_shape")
+            input_info.append(
+                helper.make_tensor_value_info(
+                    "output_shape", TensorProto.INT64, list(output_shape.shape)
+                )
+            )
+            input_values.append(output_shape)
+        else:
+            # Compute expected output shape
+            output_shape = np.asarray(([1, 1] + list(strides))) * np.asarray(list(data.shape))
+            output_shape += np.asarray(([0, 0] + list(kernel_shape))) - np.asarray(
+                ([0, 0] + list(strides))
+            )
+            if pads is not None:
+                # ONNX pads are [begins..., ends...]; sum begin + end for each axis.
+                output_shape -= np.asarray(
+                    [0, 0] + list(np.sum(np.reshape(list(pads), [2, -1]), axis=0))
+                )
+        output_shape = [int(i) for i in output_shape]
+
+        node = helper.make_node(
+            "MaxUnpool", inputs=input_names, outputs=["y"], kernel_shape=kernel_shape
+        )
+
+        if pads is not None:
+            pad_attr = helper.make_attribute("pads", pads)
+            node.attribute.append(pad_attr)
+
+        if strides is not None:
+            strides_attr = helper.make_attribute("strides", strides)
+            node.attribute.append(strides_attr)
+
+        graph = helper.make_graph(
+            [node],
+            "maxunpool_test",
+            inputs=input_info,
+            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)],
+        )
+
+        model = helper.make_model(graph, producer_name="size_test")
+
+        verify_with_ort_with_inputs(model, input_values, use_vm=True, opset=11)
+
+    # Basic test
+    xT = np.array([[[[5, 6], [7, 8]]]], dtype=np.float32)
+    xI = np.array([[[[0, 7], [13, 15]]]], dtype=np.int64)
+    verify_maxunpool(xT, xI, [2, 2], strides=[2, 2])
+    # Small stride
+    verify_maxunpool(xT, xI, [2, 2], strides=[1, 1])
+    # Big kernel
+    verify_maxunpool(xT, xI, [3, 3], strides=[2, 2])
+    # With output shape
+    output_shape = np.array((1, 1, 5, 5), dtype=np.int64)
+    verify_maxunpool(xT, xI, [2, 2], strides=[2, 2], output_shape=output_shape)
+    # With explicit reverse padding
+    pads = np.asarray([1, 1, 1, 1]).astype(np.int64)
+    verify_maxunpool(xT, xI, [2, 2], strides=[2, 2], pads=pads)
+
+
 if __name__ == "__main__":
     test_flatten()
     test_reshape()
@@ -3992,3 +4062,4 @@ def verify_size(indata):
     test_range()
     test_loop()
     test_size()
+    test_maxunpool()