Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,62 @@ class MaxPool(Pool):
name = "max_pool"


class MaxUnpool(OnnxOpConverter):
"""Operator converter for MaxUnpool"""

@classmethod
def _impl_v11(cls, inputs, attr, params):
# Unpack inputs and attributes
data = inputs[0]
data_type = infer_type(data).checked_type.dtype
indices = inputs[1]
output_shape = inputs[2]
kernel_shape = attr.get("kernel_shape")
pads = attr.get("pads", None)
strides = attr.get("strides", [1] * len(kernel_shape))

# Compute the proper output shape before padding.
multiplier = _op.concatenate(
[_expr.const([1, 1], dtype="int64"), _expr.const(list(strides), dtype="int64")], axis=0
)
total_output_shape = multiplier * _op.shape_of(data, dtype="int64")
# Add extra dimensions from kernel size and stride mismatch
total_output_shape += _op.concatenate(
[_expr.const([0, 0], "int64"), _expr.const(list(kernel_shape), "int64")], axis=0
) - _op.concatenate(
[_expr.const([0, 0], "int64"), _expr.const(list(strides), "int64")], axis=0
)

# Compute padding amount if output shape is specified.
if output_shape is not None:
total_output_shape = output_shape

elif pads is not None:
# Get pads in the proper format for relay.
pads = _op.concatenate(
[_expr.const([0, 0, 0, 0], "int64"), _expr.const(list(pads), "int64")], axis=0
)
pads = _op.reshape(pads, [-1, 2])
# Compute the total padding per axis.
total_pad = _op.sum(pads, axis=-1)
# Reversing maxpool means that padding actually makes our output smaller.
total_output_shape = total_output_shape - total_pad

# Create a tensor of zeros then scatter our data through it.
zeros_tensor = _op.zeros(total_output_shape, data_type)
# We need to flatten all our tensors before scattering.
flat_tensor = _op.scatter(
_op.reshape(zeros_tensor, [-1]),
_op.reshape(indices, [-1]),
_op.reshape(data, [-1]),
axis=0,
)
# Now reshape back to prepadded shape.
output_tensor = _op.reshape(flat_tensor, total_output_shape)

return output_tensor


class LpPool(OnnxOpConverter):
"""A helper class for lppool op converters."""

Expand Down Expand Up @@ -2330,6 +2386,7 @@ def _get_convert_map(opset):
"AveragePool": AveragePool.get_converter(opset),
"LpPool": LpPool.get_converter(opset),
"MaxPool": MaxPool.get_converter(opset),
"MaxUnpool": MaxUnpool.get_converter(opset),
"Conv": Conv.get_converter(opset),
"ConvTranspose": ConvTranspose.get_converter(opset),
"GlobalAveragePool": Renamer("global_avg_pool2d"),
Expand Down
69 changes: 69 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3915,6 +3915,74 @@ def verify_size(indata):
verify_size(input_data)


@tvm.testing.uses_gpu
def test_maxunpool():
def verify_maxunpool(data, indices, kernel_shape, strides, output_shape=None, pads=None):
input_names = ["xT", "xI"]
input_info = [
helper.make_tensor_value_info("xT", TensorProto.FLOAT, list(data.shape)),
helper.make_tensor_value_info("xI", TensorProto.INT64, list(indices.shape)),
]
input_values = [data, indices]
if output_shape is not None:
input_names.append("output_shape")
input_info.append(
helper.make_tensor_value_info(
"output_shape", TensorProto.INT64, list(output_shape.shape)
)
)
input_values.append(output_shape)
else:
# Compute expected output shape
output_shape = np.asarray(([1, 1] + list(strides))) * np.asarray(list(data.shape))
output_shape += np.asarray(([0, 0] + list(kernel_shape))) - np.asarray(
([0, 0] + list(strides))
)
if pads is not None:
output_shape -= np.asarray(
[0, 0] + list(np.sum(np.reshape(list(pads), [-1, 2]), axis=-1))
)
output_shape = [int(i) for i in output_shape]

node = helper.make_node(
"MaxUnpool", inputs=input_names, outputs=["y"], kernel_shape=kernel_shape
)

if pads is not None:
pad_attr = helper.make_attribute("pads", pads)
node.attribute.append(pad_attr)

if strides is not None:
strides_attr = helper.make_attribute("strides", strides)
node.attribute.append(strides_attr)

graph = helper.make_graph(
[node],
"maxunpool_test",
inputs=input_info,
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)],
)

model = helper.make_model(graph, producer_name="size_test")

verify_with_ort_with_inputs(model, input_values, use_vm=True, opset=11)

# Basic test
xT = np.array([[[[5, 6], [7, 8]]]], dtype=np.float32)
xI = np.array([[[[0, 7], [13, 15]]]], dtype=np.int64)
verify_maxunpool(xT, xI, [2, 2], strides=[2, 2])
# Small stride
verify_maxunpool(xT, xI, [2, 2], strides=[1, 1])
# Big kernel
verify_maxunpool(xT, xI, [3, 3], strides=[2, 2])
# With output shape
output_shape = np.array((1, 1, 5, 5), dtype=np.int64)
verify_maxunpool(xT, xI, [2, 2], strides=[2, 2], output_shape=output_shape)
# With explicit reverse padding
pads = np.asarray([1, 1, 1, 1]).astype(np.int64)
verify_maxunpool(xT, xI, [2, 2], strides=[2, 2], pads=pads)


if __name__ == "__main__":
test_flatten()
test_reshape()
Expand Down Expand Up @@ -3992,3 +4060,4 @@ def verify_size(indata):
test_range()
test_loop()
test_size()
test_maxunpool()