diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 462d1cf92c01..5777f51fe296 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -34,14 +34,15 @@ Not all TVM kernels currently support dynamic shapes, please file an issue on
 github.com/apache/tvm/issues if you hit an error with dynamic kernels.
 """
+import math
 import warnings
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as _np
 import onnx.onnx_ml_pb2

 import tvm
-from tvm import relax, tir, topi
+from tvm import TVMError, relax, tir, topi
 from tvm.ir import IRModule
 from tvm.ir.supply import NameSupply
 from tvm.tir.generic import cast
@@ -236,28 +237,176 @@ def _impl_v13(cls, bb, inputs, attr, params):
         return relax.op.matmul(inputs[0], inputs[1])


-class Div(OnnxOpConverter):
-    """Converts an onnx Div node into an equivalent Relax expression."""
+class BinaryBase(OnnxOpConverter):
+    """Converts an onnx BinaryBase node into an equivalent Relax expression."""
+
+    numpy_op: Callable = None
+    relax_op: Callable = None

     @classmethod
-    def _impl_v14(cls, bb, inputs, attr, params):
+    def _impl_v1(cls, bb, inputs, attr, params):
+        if cls.numpy_op is None or cls.relax_op is None:
+            raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
         if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            output = inputs[0].data.numpy() / inputs[1].data.numpy()
+            output = cls.numpy_op(  # pylint: disable=not-callable
+                inputs[0].data.numpy(), inputs[1].data.numpy()
+            )
             return relax.const(output, inputs[0].struct_info.dtype)
         if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
             x = (
-                int(inputs[0].value)
+                _np.array(inputs[0].value)
                 if isinstance(inputs[0], relax.PrimValue)
                 else inputs[0].data.numpy()
             )
             y = (
-                int(inputs[1].value)
+                _np.array(inputs[1].value)
                 if isinstance(inputs[1], relax.PrimValue)
                 else inputs[1].data.numpy()
             )
-            return relax.PrimValue(int(x / y))
+            return relax.PrimValue(cls.numpy_op(x, y))  # pylint: disable=not-callable
+
+        return cls.relax_op(inputs[0], inputs[1])  # pylint: disable=not-callable
+
+
+class Add(BinaryBase):
+    """Converts an onnx Add node into an equivalent Relax expression."""
+
+    numpy_op = _np.add
+    relax_op = relax.op.add
+
+
+class Sub(BinaryBase):
+    """Converts an onnx Sub node into an equivalent Relax expression."""
+
+    numpy_op = _np.subtract
+    relax_op = relax.op.subtract
+
+
+class Mul(BinaryBase):
+    """Converts an onnx Mul node into an equivalent Relax expression."""
+
+    numpy_op = _np.multiply
+    relax_op = relax.op.multiply
+
+
+class Div(BinaryBase):
+    """Converts an onnx Div node into an equivalent Relax expression."""
+
+    numpy_op = _np.divide
+    relax_op = relax.op.divide
+
+
+class Pow(BinaryBase):
+    """Converts an onnx Pow node into an equivalent Relax expression."""
+
+    numpy_op = _np.power
+    relax_op = relax.op.power
+
+
+class And(BinaryBase):
+    """Converts an onnx And node into an equivalent Relax expression."""
+
+    numpy_op = _np.logical_and
+    relax_op = relax.op.logical_and

-        return relax.op.divide(inputs[0], inputs[1])
+
+class Or(BinaryBase):
+    """Converts an onnx Or node into an equivalent Relax expression."""
+
+    numpy_op = _np.logical_or
+    relax_op = relax.op.logical_or
+
+
+class Xor(BinaryBase):
+    """Converts an onnx Xor node into an equivalent Relax expression."""
+
+    numpy_op = _np.logical_xor
+    relax_op = relax.op.logical_xor
+
+
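# ---------------------------------------------------------------------------
# Reviewer aside (illustrative only, not part of this patch): a minimal sketch
# of how the BinaryBase subclasses above behave at import time. With two
# relax.Constant inputs the converter folds the result eagerly via `numpy_op`;
# otherwise it falls through to `relax_op`. Assumes the classes above are in
# scope; `bb` is unused on the folding paths.
import numpy as np
from tvm import relax

lhs = relax.const(np.array([1.0, 2.0], dtype="float32"))
rhs = relax.const(np.array([3.0, 4.0], dtype="float32"))
folded = Add._impl_v1(bb=None, inputs=[lhs, rhs], attr={}, params={})
assert isinstance(folded, relax.Constant)  # already evaluated to [4.0, 6.0]

var = relax.Var("x", lhs.struct_info)
emitted = Add._impl_v1(bb=None, inputs=[lhs, var], attr={}, params={})
# `emitted` is a relax call wrapping relax.op.add, left for the block builder.
# ---------------------------------------------------------------------------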
+class Less(BinaryBase): + """Converts an onnx Less node into an equivalent Relax expression.""" + + numpy_op = _np.less + relax_op = relax.op.less + + +class LessOrEqual(BinaryBase): + """Converts an onnx LessEqual node into an equivalent Relax expression.""" + + numpy_op = _np.less_equal + relax_op = relax.op.less_equal + + +class Greater(BinaryBase): + """Converts an onnx Greater node into an equivalent Relax expression.""" + + numpy_op = _np.greater + relax_op = relax.op.greater + + +class GreaterOrEqual(BinaryBase): + """Converts an onnx GreaterEqual node into an equivalent Relax expression.""" + + numpy_op = _np.greater_equal + relax_op = relax.op.greater_equal + + +class Equal(OnnxOpConverter): + """Converts an onnx Equal node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + if all([isinstance(inp, relax.Constant) for inp in inputs]): + output = inputs[0].data.numpy() == inputs[1].data.numpy() + return relax.const(output, output.dtype) + elif all([isinstance(inp, (relax.Constant, relax.ShapeExpr)) for inp in inputs]): + lhs = get_prim_expr_list(inputs[0]) + rhs = get_prim_expr_list(inputs[1]) + if len(lhs) != len(rhs): + raise ValueError("Cannot compare two tensors with different shapes") + output = [tvm.ir.structural_equal(l, r) for l, r in zip(lhs, rhs)] + return relax.const(output, "bool") + return relax.op.equal(inputs[0], inputs[1]) + + +class BitwiseBase(BinaryBase): + """Converts an onnx BitwiseBase node into an equivalent Relax expression.""" + + @classmethod + def base_impl(cls, bb, inputs, attr, params, py_func, relax_op): + valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"] + for num, inp in enumerate(inputs): + if inp.struct_info.dtype not in valid_types: + raise ValueError( + f"Bitwise operations expect all inputs to have integer types, " + f"got {inp.struct_info.dtype} for input {num}" + ) + return BinaryBase.base_impl(bb, inputs, attr, params, py_func, relax_op) + + +class BitwiseAnd(BitwiseBase): + """Converts an onnx BitwiseAnd node into an equivalent Relax expression.""" + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y, relax.op.bitwise_and) + + +class BitwiseOr(BitwiseBase): + """Converts an onnx BitwiseOr node into an equivalent Relax expression.""" + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y, relax.op.bitwise_or) + + +class BitwiseXor(BitwiseBase): + """Converts an onnx BitwiseXor node into an equivalent Relax expression.""" + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y, relax.op.bitwise_xor) class Sigmoid(OnnxOpConverter): @@ -277,6 +426,15 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.nn.softmax(inputs[0], axis=axis) +class LogSoftmax(OnnxOpConverter): + """Converts an onnx LogSoftmax node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + axis = attr.get("axis", -1) + return relax.op.nn.log_softmax(inputs[0], axis=axis) + + class Transpose(OnnxOpConverter): """Converts an onnx Transpose node into an equivalent Relax expression.""" @@ -375,67 +533,6 @@ def is_shape_like(x: Any) -> bool: return relax.op.concat(inputs, axis=axis) -class Add(OnnxOpConverter): - """Convert an onnx Add node into an equivalent Relax expression.""" - - 
@classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = inputs[0].data.numpy() + inputs[1].data.numpy() - return relax.const(output, output.dtype) - # If primvalues are involved, handle them directly. - if any([isinstance(inp, relax.PrimValue) for inp in inputs]): - x = ( - int(inputs[0].value) - if isinstance(inputs[0], relax.PrimValue) - else inputs[0].data.numpy() - ) - y = ( - int(inputs[1].value) - if isinstance(inputs[1], relax.PrimValue) - else inputs[1].data.numpy() - ) - return relax.PrimValue(int(x + y)) - return relax.op.add(inputs[0], inputs[1]) - - -class Sum(OnnxOpConverter): - """Convert an onnx Sum node into an equivalent Relax expression.""" - - @classmethod - def _impl_v1(cls, bb, inputs, attr, params): - for in_index in range(len(inputs) - 1): - inputs[in_index + 1] = relax.op.add(inputs[in_index], inputs[in_index + 1]) - - return inputs[len(inputs) - 1] - - -class Mul(OnnxOpConverter): - """Convert an onnx Mul node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - # When all inputs are constant, directly multiply. - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = inputs[0].data.numpy() * inputs[1].data.numpy() - return relax.const(output, output.dtype) - # If primvalues are involved, handle them directly. - if any([isinstance(inp, relax.PrimValue) for inp in inputs]): - x = ( - int(inputs[0].value) - if isinstance(inputs[0], relax.PrimValue) - else inputs[0].data.numpy() - ) - y = ( - int(inputs[1].value) - if isinstance(inputs[1], relax.PrimValue) - else inputs[1].data.numpy() - ) - return relax.PrimValue(int(x * y)) - - return relax.op.multiply(inputs[0], inputs[1]) - - class Cast(OnnxOpConverter): """Convert an onnx Cast node into an equivalent Relax expression.""" @@ -482,8 +579,38 @@ def _impl_v13(cls, bb, inputs, attr, params): shape_val = data[np_index] return relax.PrimValue(shape_val) - # TODO(jwfromm) Make relax.take work with other indices shape. - return bb.emit_te(topi.take, data, indices, axis) + return relax.op.take(data, indices, axis) + + +class Scatter(OnnxOpConverter): + """Convert an onnx Scatter node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + axis = attr.get("axis", 0) + return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis) + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + raise ValueError("Scatter is deprecated in ONNX 11") + + +class ScatterElements(OnnxOpConverter): + """Convert an onnx ScatterElements node into an equivalent Relax expression.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + axis = attr.get("axis", 0) + return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis) + + +class Size(OnnxOpConverter): + """Convert an onnx Size node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + # TODO(tvm-team): add native support for size op + return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0]))) class Gemm(OnnxOpConverter): @@ -542,29 +669,6 @@ def _impl_v13(cls, bb, inputs, attr, params): return out -class Gelu(OnnxOpConverter): - """Operator converter for Gelu from Microsoft onnxruntime contrib opset. 
- - gelu(x) = 0.5x(1 + erf(x/sqrt(2))) - """ - - @classmethod - def _impl_v1(cls, bb, inputs, attr, params): - return relax.op.nn.gelu(inputs[0]) - - -class BiasGelu(OnnxOpConverter): - """Operator converter for BiasGelu from Microsoft onnxruntime contrib opset. - - bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2))) - """ - - @classmethod - def _impl_v1(cls, bb, inputs, attr, params): - inp = relax.op.add(inputs[0], inputs[1]) - return relax.op.nn.gelu(inp) - - class Where(OnnxOpConverter): """Convert an onnx Where node into an equivalent Relax expression.""" @@ -605,24 +709,6 @@ def _impl_v13(cls, bb, inputs, attr, params): return results -class Equal(OnnxOpConverter): - """Converts an onnx Equal node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = inputs[0].data.numpy() == inputs[1].data.numpy() - return relax.const(output, output.dtype) - elif all([isinstance(inp, (relax.Constant, relax.ShapeExpr)) for inp in inputs]): - lhs = get_prim_expr_list(inputs[0]) - rhs = get_prim_expr_list(inputs[1]) - if len(lhs) != len(rhs): - raise ValueError("Cannot compare two tensors with different shapes") - output = [tvm.ir.structural_equal(l, r) for l, r in zip(lhs, rhs)] - return relax.const(output, "bool") - return relax.op.equal(inputs[0], inputs[1]) - - class Shape(OnnxOpConverter): """Converts an onnx Equal node into an equivalent Relax expression.""" @@ -643,22 +729,6 @@ def _impl_v13(cls, bb, inputs, attr, params): return data_info.shape -class Tanh(OnnxOpConverter): - """Converts an onnx Tanh node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - return relax.op.tanh(inputs[0]) - - -class Sqrt(OnnxOpConverter): - """Converts an onnx Sqrt node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - return relax.op.sqrt(inputs[0]) - - class Trilu(OnnxOpConverter): """Given a 2-D matrix or batches of 2-D matrices, returns the upper or lower triangular part of the tensor(s) @@ -691,12 +761,157 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.nn.relu(inputs[0]) -class Pow(OnnxOpConverter): - """Converts an onnx Pow node into an equivalent Relax expression.""" +class Elu(OnnxOpConverter): + """Converts an onnx Elu node into an equivalent Relax expression.""" @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - return relax.op.power(inputs[0], inputs[1]) + def _impl_v1(cls, bb, inputs, attr, params): + alpha = float(attr.get("alpha", 1.0)) + return relax.expr.const(-alpha) * relax.op.nn.relu( + relax.expr.const(1.0) - relax.op.exp(inputs[0]) + ) + relax.op.nn.relu(inputs[0]) + + +class Selu(OnnxOpConverter): + """Converts an onnx Selu node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + alpha = attr.get("alpha", 1.67326319217681884765625) + gamma = attr.get("gamma", 1.05070102214813232421875) + return relax.const(gamma) * ( + relax.const(-alpha) * relax.op.nn.relu(relax.const(1.0) - relax.op.exp(inputs[0])) + + relax.op.nn.relu(inputs[0]) + ) + + +class Mish(OnnxOpConverter): + """Converts an onnx Mish node into an equivalent Relax expression. 
+
+    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
+    """
+
+    @classmethod
+    def _impl_v18(cls, bb, inputs, attr, params):
+        dtype = inputs[0].checked_type.dtype
+        return inputs[0] * relax.op.tanh(
+            relax.op.log(relax.const(1.0, dtype) + relax.op.exp(inputs[0]))
+        )
+
+
+class PRelu(OnnxOpConverter):
+    """Converts an onnx PRelu node into an equivalent Relax expression.
+
+    f(x) = slope * x for x < 0, x for x >= 0
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        slope = inputs[1]
+        # TODO(tvm-team): Should add a new op for this.
+        return x * slope + relax.op.nn.relu(x) * (relax.const(1.0) - slope)
+
+
+class ThresholdedRelu(OnnxOpConverter):
+    """Converts an onnx ThresholdedRelu node into an equivalent Relax expression.
+
+    f(x) = x for x > alpha, 0 otherwise
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        alpha = attr.get("alpha", 1.0)
+        return relax.op.greater(x, relax.const(alpha)).astype("float32") * x
+
+
+class LeakyRelu(OnnxOpConverter):
+    """Converts an onnx LeakyRelu node into an equivalent Relax expression.
+
+    f(x) = x for x > 0, alpha * x otherwise
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        alpha = attr.get("alpha", 0.01)
+        return relax.op.nn.leakyrelu(x, alpha)
+
+
+class Gelu(OnnxOpConverter):
+    """Operator converter for Gelu from Microsoft onnxruntime contrib opset.
+
+    gelu(x) = 0.5x(1 + erf(x/sqrt(2)))
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return relax.op.nn.gelu(inputs[0])
+
+
+class FastGelu(OnnxOpConverter):
+    """Operator converter for FastGelu from Microsoft onnxruntime contrib opset.
+
+    fast_gelu(x) = 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3)))
+                 = 0.5x(1 + tanh(sqrt(2/pi)x + 0.044715 * sqrt(2/pi)x^3))
+                 = 0.5x(1 + tanh(c1 * x + c2 * x^3))
+    , where
+        c1 = sqrt(2/pi)
+        c2 = 0.044715 * sqrt(2/pi)
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        if inputs[1]:
+            bias = inputs[1]
+            bias_shape = bias.struct_info.shape
+            assert len(bias_shape) == 1, "bias term must be a 1D tensor"
+            x += bias
+
+        # Declare consts
+        const_dtype = x.struct_info.dtype
+        half = relax.const(0.5, dtype=const_dtype)
+        one = relax.const(1.0, dtype=const_dtype)
+        const1 = relax.const(math.sqrt(2 / math.pi), dtype=const_dtype)
+        const2 = relax.const(0.044715 * math.sqrt(2 / math.pi), dtype=const_dtype)
+
+        # Compute FastGelu
+        term1 = relax.op.multiply(half, x)
+        term2 = relax.op.multiply(const1, x)
+        term3 = relax.op.multiply(const2, relax.op.power(x, relax.const(3, const_dtype)))
+        tanh = relax.op.tanh(relax.op.add(term2, term3))
+        return relax.op.multiply(term1, relax.op.add(one, tanh))
+
+
+class BiasGelu(OnnxOpConverter):
+    """Operator converter for BiasGelu from Microsoft onnxruntime contrib opset.
+
+    bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
+    """
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        inp = relax.op.add(inputs[0], inputs[1])
+        return relax.op.nn.gelu(inp)
+
+
+class Shrink(OnnxOpConverter):
+    """Converts an onnx Shrink node into an equivalent Relax expression.
+ + f(x) = x + bias if x > lambd, x - bias if x < -lambd, 0 otherwise + """ + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + x = inputs[0] + dtype = x.struct_info.dtype + lambd = relax.const(attr.get("lambd", 0.5), dtype) + bias = relax.const(attr.get("bias", 0.0), dtype) + zeros = relax.op.zeros_like(x) + return relax.op.where(x > lambd, x - bias, zeros) + relax.op.where( + x < -lambd, x + bias, zeros + ) class Conv(OnnxOpConverter): @@ -730,21 +945,55 @@ def _impl_v11(cls, bb, inputs, attr, params): weight=inputs[1], strides=attr.get("strides", 1), padding=attr.get("pads", 0), - dilation=attr.get("dilation", 1), + dilation=attr.get("dilations", 1), groups=attr.get("group", 1), data_layout=data_layout, kernel_layout=kernel_layout, ) ) if inputs[2] is not None: - bias = relax.op.reshape( - inputs[2], - [1, -1] - + [ - 1, - ] - * (ndim - 2), - ) + bias = relax.op.reshape(inputs[2], [1, -1] + [1] * (ndim - 2)) + conv_out = relax.op.add(conv_out, bias) + + return conv_out + + +class ConvTranspose(OnnxOpConverter): + """Converts an onnx ConvTranspose node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + if hasattr(inputs[0].struct_info, "ndim"): + ndim = inputs[0].struct_info.ndim + else: + ndim = len(inputs[0].struct_info.shape) + + if ndim == 3: + op = relax.op.nn.conv1d_transpose + data_layout = "NCW" + kernel_layout = "IOW" + elif ndim == 4: + op = relax.op.nn.conv2d_transpose + data_layout = "NCHW" + kernel_layout = "IOHW" + elif ndim == 5: + raise NotImplementedError("Relax ConvTranspose3d not supported yet") + else: + raise NotImplementedError("Ndim > 5 not supported for convolution.") + + conv_out = op( + data=inputs[0], + weight=inputs[1], + strides=attr.get("strides", 1), + padding=attr.get("pads", 0), + dilation=attr.get("dilations", 1), + groups=attr.get("group", 1), + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + + if inputs[2] is not None: + bias = relax.op.reshape(inputs[2], [1, -1] + [1] * (ndim - 2)) conv_out = relax.op.add(conv_out, bias) return conv_out @@ -839,17 +1088,6 @@ def _impl_v9(cls, bb, inputs, attr, params): return relax.op.broadcast_to(relax.const(value, dtype), shape) -class Sub(OnnxOpConverter): - """Converts an onnx Sub node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = inputs[0].data.numpy() - inputs[1].data.numpy() - return relax.const(output, output.dtype) - return relax.op.subtract(inputs[0], inputs[1]) - - class Sin(OnnxOpConverter): """Converts an onnx Sin node into an equivalent Relax expression.""" @@ -858,6 +1096,14 @@ def _impl_v7(cls, bb, inputs, attr, params): return relax.op.sin(inputs[0]) +class Sinh(OnnxOpConverter): + """Converts an onnx Sinh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.sinh(inputs[0]) + + class Cos(OnnxOpConverter): """Converts an onnx Cos node into an equivalent Relax expression.""" @@ -866,6 +1112,78 @@ def _impl_v7(cls, bb, inputs, attr, params): return relax.op.cos(inputs[0]) +class Cosh(OnnxOpConverter): + """Converts an onnx Cosh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.cosh(inputs[0]) + + +class Tan(OnnxOpConverter): + """Converts an onnx Tan node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, 
attr, params): + return relax.op.tan(inputs[0]) + + +class Tanh(OnnxOpConverter): + """Converts an onnx Tanh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, attr, params): + return relax.op.tanh(inputs[0]) + + +class Acos(OnnxOpConverter): + """Converts an onnx Acos node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, attr, params): + return relax.op.acos(inputs[0]) + + +class Acosh(OnnxOpConverter): + """Converts an onnx Acosh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.acosh(inputs[0]) + + +class Asin(OnnxOpConverter): + """Converts an onnx Asin node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, attr, params): + return relax.op.asin(inputs[0]) + + +class Asinh(OnnxOpConverter): + """Converts an onnx Asinh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.asinh(inputs[0]) + + +class Atan(OnnxOpConverter): + """Converts an onnx Atan node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, attr, params): + return relax.op.atan(inputs[0]) + + +class Atanh(OnnxOpConverter): + """Converts an onnx Atanh node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.atanh(inputs[0]) + + class Neg(OnnxOpConverter): """Converts an onnx Neg node into an equivalent Relax expression.""" @@ -877,47 +1195,121 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.negative(inputs[0]) -class Abs(OnnxOpConverter): - """Converts an onnx Abs node into an equivalent Relax expression.""" +class Abs(OnnxOpConverter): + """Converts an onnx Abs node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + if isinstance(inputs[0], relax.Constant): + output = _np.abs(inputs[0].data.numpy()) + return relax.const(output, output.dtype) + return relax.op.abs(inputs[0]) + + +class Reciprocal(OnnxOpConverter): + """Converts an onnx Reciprocal node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + input_dtype = inputs[0].struct_info.dtype + return relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0]) + + +class Floor(OnnxOpConverter): + """Converts an onnx Floor node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return relax.op.floor(inputs[0]) + + +class Ceil(OnnxOpConverter): + """Converts an onnx Ceil node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return relax.op.ceil(inputs[0]) + + +class Round(OnnxOpConverter): + """Converts an onnx Round node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return relax.op.round(inputs[0]) + + +class IsInf(OnnxOpConverter): + """Converts an onnx IsInf node into an equivalent Relax expression.""" + + @classmethod + def _impl_v10(cls, bb, inputs, attr, params): + return relax.op.isinf(inputs[0]) + + +class IsNaN(OnnxOpConverter): + """Converts an onnx IsNaN node into an equivalent Relax expression.""" @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if isinstance(inputs[0], relax.Constant): - output = _np.abs(inputs[0].data.numpy()) - return relax.const(output, output.dtype) - return 
relax.op.abs(inputs[0]) + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.isnan(inputs[0]) -class Min(OnnxOpConverter): - """Converts an onnx Min node into an equivalent Relax expression.""" +class Sqrt(OnnxOpConverter): + """Converts an onnx Sqrt node into an equivalent Relax expression.""" @classmethod - def _impl_v13(cls, bb, inputs, attr, params): + def _impl_v1(cls, bb, inputs, attr, params): + return relax.op.sqrt(inputs[0]) + + +class MultiInputBase(OnnxOpConverter): + """Converts an onnx MultiInputBase node into an equivalent Relax expression.""" + + numpy_op: Callable = None + relax_op: Callable = None + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + if cls.numpy_op is None or cls.relax_op is None: + raise NotImplementedError("numpy_op and relax_op must be defined for MultiInputBase") if all([isinstance(inp, relax.Constant) for inp in inputs]): np_inputs = [inp.data.numpy() for inp in inputs] - output = _np.minimum(*np_inputs) + output = cls.numpy_op(*np_inputs) # pylint: disable=not-callable return relax.const(output, output.dtype) # Expand inputs, stack them, then perform minimum over the new axis. inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs] stacked_tensor = relax.op.concat(inputs, axis=0) - return relax.op.min(stacked_tensor, axis=0) + return cls.relax_op(stacked_tensor, axis=0) # pylint: disable=not-callable + + +class Min(MultiInputBase): + """Converts an onnx Min node into an equivalent Relax expression.""" + + numpy_op = _np.min + relax_op = relax.op.min -class Max(OnnxOpConverter): +class Max(MultiInputBase): """Converts an onnx Max node into an equivalent Relax expression.""" - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - np_inputs = [inp.data.numpy() for inp in inputs] - output = _np.maximum(*np_inputs) - return relax.const(output, output.dtype) + numpy_op = _np.max + relax_op = relax.op.max - # Expand inputs, stack them, then perform maximum over the new axis. 
- inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs] - stacked_tensor = relax.op.concat(inputs, axis=0) - return relax.op.max(stacked_tensor, axis=0) + +class Mean(MultiInputBase): + """Converts an onnx Mean node into an equivalent Relax expression.""" + + numpy_op = _np.mean + relax_op = relax.op.mean + + +class Sum(MultiInputBase): + """Converts an onnx Sum node into an equivalent Relax expression.""" + + numpy_op = _np.sum + relax_op = relax.op.sum class Log(OnnxOpConverter): @@ -956,26 +1348,22 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.exp(data) -class Less(OnnxOpConverter): - """Converts an onnx Less node into an equivalent Relax expression.""" +class Softplus(OnnxOpConverter): + """Converts an onnx Softplus node into an equivalent Relax expression.""" @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = _np.less(inputs[0].data.numpy(), inputs[1].data.numpy()) - return relax.const(output, output.dtype) - return relax.op.less(inputs[0], inputs[1]) + def _impl_v1(cls, bb, inputs, attr, params): + dtype = inputs[0].struct_info.dtype + return relax.op.log(relax.op.exp(inputs[0]) + relax.const(1, dtype=dtype)) -class LessOrEqual(OnnxOpConverter): - """Converts an onnx LessOrEqual node into an equivalent Relax expression.""" +class Softsign(OnnxOpConverter): + """Converts an onnx Softsign node into an equivalent Relax expression.""" @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = _np.less_equal(inputs[0].data.numpy(), inputs[1].data.numpy()) - return relax.const(output, output.dtype) - return relax.op.less_equal(inputs[0], inputs[1]) + def _impl_v1(cls, bb, inputs, attr, params): + dtype = inputs[0].struct_info.dtype + return inputs[0] / (relax.op.abs(inputs[0]) + relax.const(1, dtype=dtype)) class Split(OnnxOpConverter): @@ -1456,6 +1844,20 @@ def _impl_v15(cls, bb, inputs, attr, params): ) +class MeanVarianceNormalization(OnnxOpConverter): + """Converts an onnx MeanVarianceNormalization node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + data = inputs[0] + axis = attr.get("axes", (0, 2, 3)) + data_mean = relax.op.mean(data, axis=axis, keepdims=True) + data_mean_squared = relax.op.power(data_mean, relax.const(2, dtype="float32")) + data_squared = relax.op.power(data, relax.const(2, dtype="float32")) + data_squared_mean = relax.op.mean(data_squared, axis=axis, keepdims=True) + return (data - data_mean) / relax.op.sqrt(data_squared_mean - data_mean_squared) + + class Pool(OnnxOpConverter): """A helper class for pool op converters.""" @@ -1557,16 +1959,79 @@ class GlobalAveragePool(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): rank = len(inputs[0].struct_info.shape) - if rank == 3: - return relax.op.nn.adaptive_avg_pool1d(inputs[0], 1) - elif rank == 4: - return relax.op.nn.adaptive_avg_pool2d(inputs[0], 1) - elif rank == 5: - return relax.op.nn.adaptive_avg_pool3d(inputs[0], 1) - raise NotImplementedError( - "Global average pooling is only implemented for 1D, 2D, and 3D kernels, got %dD." 
- % (rank - 2) + axes = list(range(2, rank)) + return relax.op.mean(inputs[0], axis=axes, keepdims=True) + + +class GlobalMaxPool(OnnxOpConverter): + """Converts an onnx GlobalMaxPool node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + rank = len(inputs[0].struct_info.shape) + axes = list(range(2, rank)) + return relax.op.max(inputs[0], axis=axes, keepdims=True) + + +class GlobalLpPool(OnnxOpConverter): + """Converts an onnx GlobalLpPool node into an equivalent Relax expression.""" + + @classmethod + def _impl_v2(cls, bb, inputs, attr, params): + p = attr.get("p", 2.0) + dtype = inputs[0].struct_info.dtype + rank = len(inputs[0].struct_info.shape) + axes = list(range(2, rank)) + x_abs = relax.op.abs(inputs[0]) + x_p = relax.op.power(x_abs, relax.const(p, dtype=dtype)) + x_sum = relax.op.sum(x_p, axes, keepdims=True) + return relax.op.power(x_sum, relax.const(1.0 / p, dtype=dtype)) + + +class MaxUnpool(OnnxOpConverter): + """Converts an onnx MaxUnpool node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + data = inputs[0] + indices = inputs[1] + output_shape = inputs[2] + kernel_shape = attr.get("kernel_shape") + pads = attr.get("pads", [0] * len(kernel_shape) * 2) + strides = attr.get("strides", [1] * len(kernel_shape)) + + multiplier = _np.concatenate([[1, 1], list(strides)]) + shape = [v.value for v in data.struct_info.shape] + total_output_shape = multiplier * shape + # Add extra dimensions from kernel size and stride mismatch + total_output_shape += _np.concatenate([[0, 0], list(kernel_shape)], axis=0) + total_output_shape -= _np.concatenate([[0, 0], list(strides)], axis=0) + + if output_shape is not None: + total_output_shape = output_shape + + elif pads is not None: + # Get pads in the proper format for relay. + pads = _np.concatenate([[0, 0, 0, 0], list(pads)], axis=0) + pads = _np.reshape(pads, [-1, 2]) + # Compute the total padding per axis. + total_pad = _np.sum(pads, axis=-1) + # Reversing maxpool means that padding actually makes our output smaller. + total_output_shape = total_output_shape - total_pad + + # Create a tensor of zeros then scatter our data through it. + relax_shape = relax.ShapeExpr(total_output_shape.tolist()) + zeros_tensor = bb.emit(relax.op.zeros(relax_shape, data.struct_info.dtype)) + # We need to flatten all our tensors before scattering. + flat_tensor = relax.op.scatter_elements( + relax.op.reshape(zeros_tensor, [-1]), + relax.op.reshape(indices, [-1]), + relax.op.reshape(data, [-1]), + axis=0, ) + # Reshape our flattened data back to normal. 
+ output = relax.op.reshape(flat_tensor, relax_shape) + return output class Flatten(OnnxOpConverter): @@ -1799,6 +2264,32 @@ def _impl_v12(cls, bb, inputs, attr, params): return relax.op.argmin(data, axis, keepdims) +class TopK(OnnxOpConverter): + """Converts an onnx TopK node into an equivalent Relax expression.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + data = inputs[0] + k = inputs[1] + if not isinstance(k, relax.Constant): + raise ValueError("TopK k must be a constant") + k = int(k.data.numpy()) + axis = attr.get("axis", -1) + largest = attr.get("largest", 1) + sorted = attr.get("sorted", 1) + if sorted != 1: + raise ValueError("TopK sorted must be 1 for Relax frontend") + + return relax.op.topk(data, k, axis, ret_type="both", largest=largest) + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + data = inputs[0] + k = attr.get("k", 1) + axis = attr.get("axis", -1) + return relax.op.topk(data, k, axis, ret_type="both") + + class SkipLayerNormalization(OnnxOpConverter): """Converts a microsoft contrib SkipLayerNormalization node into a Relax expression.""" @@ -1871,26 +2362,6 @@ def _impl_v1(cls, bb, inputs, attr, params): return relax.Tuple([ln, mask_index]) -class Greater(OnnxOpConverter): - """Converts an onnx Greater node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - if all([isinstance(inp, relax.Constant) for inp in inputs]): - output = _np.greater(inputs[0].data.numpy(), inputs[1].data.numpy()) - return relax.const(output, output.dtype) - return relax.op.greater(inputs[0], inputs[1]) - - -class Reciprocal(OnnxOpConverter): - """Converts an onnx Reciprocal node into an equivalent Relax expression.""" - - @classmethod - def _impl_v13(cls, bb, inputs, attr, params): - input_dtype = inputs[0].struct_info.dtype - return relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0]) - - class OneHot(OnnxOpConverter): """Converts an onnx OneHot node into an equivalent Relax expression.""" @@ -1909,15 +2380,16 @@ def _impl_v11(cls, bb, inputs, attr, params): return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth, axis, dtype) -class Elu(OnnxOpConverter): - """Converts an onnx Elu node into an equivalent Relax expression.""" +class Unique(OnnxOpConverter): + """Converts an onnx Unique node into an equivalent Relax expression.""" @classmethod - def _impl_v1(cls, bb, inputs, attr, params): - alpha = float(attr.get("alpha", 1.0)) - return relax.expr.const(-alpha) * relax.op.nn.relu( - relax.expr.const(1.0) - relax.op.exp(inputs[0]) - ) + relax.op.nn.relu(inputs[0]) + def _impl_v11(cls, bb, inputs, attr, params): + data = inputs[0] + axis = attr.get("axis", None) + sorted = bool(attr.get("sorted", 1)) + # TODO(tvm-team): Add support for return_index, return_inverse, return_counts + return relax.op.unique(data, sorted=sorted, axis=axis) class HardSigmoid(OnnxOpConverter): @@ -1966,53 +2438,308 @@ def _impl_v1(cls, bb, inputs, attr, params): return relax.op.logical_not(inputs[0]) +class DepthToSpace(OnnxOpConverter): + """Converts an onnx DepthToSpace node into an equivalent Relax expression.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + block_size = int(attr["blocksize"]) + mode = attr.get("mode", b"DCR").decode("utf-8") + b, c, h, w = inputs[0].struct_info.shape + if mode == "DCR": + x = relax.op.reshape( + inputs[0], (b, block_size, block_size, c // (block_size**2), h, w) + ) + x = relax.op.permute_dims(x, [0, 3, 4, 1, 5, 2]) + return relax.op.reshape(x, (b, 
c // (block_size**2), h * block_size, w * block_size)) + elif mode == "CRD": + x = relax.op.reshape( + inputs[0], (b, c // (block_size**2), block_size, block_size, h, w) + ) + x = relax.op.permute_dims(x, [0, 1, 4, 2, 5, 3]) + return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size)) + else: + raise ValueError(f"Unsupported mode: {mode}, expected DCR or CRD") + + +class SpaceToDepth(OnnxOpConverter): + """Converts an onnx SpaceToDepth node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + block_size = int(attr["blocksize"]) + b, c, h, w = inputs[0].struct_info.shape + x = relax.op.reshape( + inputs[0], (b, c, h // block_size, block_size, w // block_size, block_size) + ) + x = relax.op.permute_dims(x, [0, 3, 5, 1, 2, 4]) + return relax.op.reshape( + x, (b, c * block_size * block_size, h // block_size, w // block_size) + ) + + +class SequenceConstruct(OnnxOpConverter): + """Operator converter for sequence construction op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + # Construct a tuple from input tensors. + return relax.Tuple(inputs) + + +class SequenceEmpty(OnnxOpConverter): + """Operator converter for sequence empty op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + # Construct an empty tuple. + return relax.Tuple([]) + + +class SequenceErase(OnnxOpConverter): + """Operator converter for sequence erase op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + # Erase tensor from sequence on specified position + input_sequence = inputs[0] + + if len(inputs) == 2: + position = inputs[1] + # Non constant position is not supported. + if isinstance(position, relax.Constant): + position = int(position.data.numpy()) + else: + raise NotImplementedError("Position must be a constant.") + else: + position = -1 + + seq_len = len(input_sequence) + if not -seq_len <= position < seq_len: + raise ValueError( + f"Position is out of bounds, expected [-{seq_len}, {seq_len}), got {position}" + ) + + if position < 0: + position = seq_len + position + # Convert sequence to a list, insert tensors before erased, and repackage as Tuple. + tensor_list = [input_sequence[i] for i in range(seq_len) if i != position] + # Create new tuple and return. + return relax.Tuple(tensor_list) + + +class SequenceInsert(OnnxOpConverter): + """Operator converter for sequence insert op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + # Insert a new tensor into a tuple of tensors. + input_sequence = inputs[0] + tensor_to_insert = inputs[1] + + if len(inputs) == 3: + position = inputs[2] + # Non constant position is not supported. + if isinstance(position, relax.Constant): + position = position.data.numpy() + else: + raise NotImplementedError("Position must be a constant.") + else: + position = -1 + + if position < 0: + position = len(input_sequence) + position + 1 + # Convert sequence to a list, insert new tensor, and repackage as Tuple. + tensor_list = [input_sequence[i] for i in range(len(input_sequence))] + # Insert new tensor. + tensor_list.insert(position, tensor_to_insert) + # Create new tuple and return. 
+ return relax.Tuple(tensor_list) + + +class SequenceLength(OnnxOpConverter): + """Operator converter for sequence length op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + # Get length of input sequence + return relax.const(len(inputs[0]), dtype="int64") + + +class ConcatFromSequence(OnnxOpConverter): + """Operator converter for sequence concatenation op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + axis = attr.get("axis", 0) + new_axis = attr.get("new_axis", 0) + + if new_axis == 1: + raise NotImplementedError("Insert new axis is not supported yet.") + + return relax.op.concat(inputs[0], axis=axis) + + +class SplitToSequence(OnnxOpConverter): + """Operator converter for split to sequence op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + axis = attr.get("axis", 0) + keepdims = attr.get("keepdims", 1) + + input_tensor = inputs[0] + input_shape = input_tensor.struct_info.shape + + # If split is not provided, we split all values along axis. + if len(inputs) == 1: + split = _np.array(1) + if not keepdims: + raise NotImplementedError("Only keepdims=1 is supported for now") + else: + split = inputs[1] + if not isinstance(split, relax.Constant): + raise ValueError("Only constant split supported for SplitToSequence") + split = split.data.numpy() + + if len(split.shape) == 1 and split.shape[0] > 1: + split = _np.cumsum(split) + split = list(split[:-1]) + else: + chunk_size, dim_size = int(split), input_shape[axis] + if dim_size % chunk_size != 0: + raise ValueError( + f"Dimension of size {dim_size} along axis {axis} must be " + f"evenly divisible by chunk size {chunk_size}" + ) + split = dim_size // chunk_size + + output = relax.op.split(input_tensor, split, axis=axis) + return output + + +class SequenceAt(OnnxOpConverter): + """Operator converter for sequence at op.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + input_sequence = inputs[0] + position = inputs[1] + assert isinstance( + position, relax.Constant + ), "Only constant position supported for SequenceAt" + position = int(position.data.numpy()) + return input_sequence[position] + + def _get_convert_map(): return { - "MatMul": MatMul, - "Concat": Concat, + # defs/experimental + # "Optional": Optional_, + # "OptionalHasElement": OptionalHasElement, + # "OptionalGetElement": OptionalGetElement, + # Binary operators "Add": Add, + "Sub": Sub, "Mul": Mul, - "Cast": Cast, + "Div": Div, + # "Mod": Mod, + "Less": Less, + "LessOrEqual": LessOrEqual, + "Greater": Greater, + "GreaterOrEqual": GreaterOrEqual, + "Equal": Equal, + "BitwiseAnd": BitwiseAnd, + "BitwiseOr": BitwiseOr, + "BitwiseXor": BitwiseXor, + # "BitwiseNot": BitwiseNot, + # "BitwiseShift": BitwiseShift, + "And": And, + "Or": Or, + "Xor": Xor, + "Not": Not, + # Unary operators + "Log": Log, + "Exp": Exp, + "Acos": Acos, + "Acosh": Acosh, + "Asin": Asin, + "Asinh": Asinh, + "Atan": Atan, + "Atanh": Atanh, + "Cos": Cos, + "Cosh": Cosh, + "Sin": Sin, + "Sinh": Sinh, + "Tan": Tan, + "Tanh": Tanh, + "Neg": Neg, + "Abs": Abs, + "Reciprocal": Reciprocal, + "Floor": Floor, + "Ceil": Ceil, + "Round": Round, + "IsInf": IsInf, + "IsNaN": IsNaN, + "Sqrt": Sqrt, + "Relu": Relu, + "Selu": Selu, + "Mish": Mish, + "Trilu": Trilu, + "PRelu": PRelu, + "LeakyRelu": LeakyRelu, + "ThresholdedRelu": ThresholdedRelu, + "Elu": Elu, + "Gelu": Gelu, + "FastGelu": FastGelu, + "BiasGelu": BiasGelu, + "HardSigmoid": HardSigmoid, + "HardSwish": HardSwish, + "Sign": Sign, + "Softplus": Softplus, + "Softsign": Softsign, + 
"Shrink": Shrink, + "Erf": Erf, "Sum": Sum, - "Gather": Gather, + "Min": Min, + "Max": Max, + "Mean": Mean, + "Cast": Cast, "Gemm": Gemm, + "MatMul": MatMul, + # "MatMulInteger": MatMulInteger, + # "MatMulInteger16": MatMulInteger16, "Reshape": Reshape, - "Div": Div, "Sigmoid": Sigmoid, "Softmax": Softmax, + "LogSoftmax": LogSoftmax, + # "Hardmax": Hardmax, "Transpose": Transpose, "Unsqueeze": Unsqueeze, - "Gelu": Gelu, - "BiasGelu": BiasGelu, "Where": Where, + "Concat": Concat, "Clip": Clip, - "Equal": Equal, "Shape": Shape, - "Tanh": Tanh, - "Sqrt": Sqrt, - "Trilu": Trilu, - "Relu": Relu, - "Conv": Conv, "Pow": Pow, - "Erf": Erf, "CumSum": CumSum, "Squeeze": Squeeze, "Constant": Constant, - "Sub": Sub, - "Sin": Sin, - "Cos": Cos, - "Neg": Neg, - "Abs": Abs, - "Min": Min, - "Max": Max, - "Log": Log, - "Exp": Exp, - "Less": Less, - "LessOrEqual": LessOrEqual, + "Gather": Gather, + # "GatherElements": GatherElements, + # "GatherND": GatherND, + "Scatter": Scatter, + "ScatterElements": ScatterElements, + # "ScatterND": ScatterND, + # "Compress": Compress, + "Size": Size, + # "EyeLike": EyeLike, + # Normalization + "BatchNormalization": BatchNormalization, "LayerNormalization": LayerNormalization, "SkipLayerNormalization": SkipLayerNormalization, "EmbedLayerNormalization": EmbedLayerNormalization, "InstanceNormalization": InstanceNormalization, + "MeanVarianceNormalization": MeanVarianceNormalization, # defs/reduction "ReduceMax": ReduceMax, "ReduceMin": ReduceMin, @@ -2026,6 +2753,7 @@ def _get_convert_map(): "ReduceL2": ReduceL2, "ArgMax": ArgMax, "ArgMin": ArgMin, + "TopK": TopK, "Expand": Expand, "ConstantOfShape": ConstantOfShape, "Slice": Slice, @@ -2033,23 +2761,42 @@ def _get_convert_map(): "Pad": Pad, "Split": Split, "Tile": Tile, - "BatchNormalization": BatchNormalization, - "MaxPool": MaxPool, "AveragePool": AveragePool, + "MaxPool": MaxPool, + # "LpPool": LpPool, "GlobalAveragePool": GlobalAveragePool, + "GlobalMaxPool": GlobalMaxPool, + "GlobalLpPool": GlobalLpPool, + "MaxUnpool": MaxUnpool, + "Conv": Conv, + "ConvTranspose": ConvTranspose, "Flatten": Flatten, "Identity": Identity, "Resize": Resize, "Einsum": Einsum, "Range": Range, - "Greater": Greater, - "Reciprocal": Reciprocal, "OneHot": OneHot, - "Elu": Elu, - "HardSigmoid": HardSigmoid, - "HardSwish": HardSwish, - "Sign": Sign, - "Not": Not, + "Unique": Unique, + # "NonZero": NonZero, + # "If": If, + # "LRN": LRN, + # "MaxRoiPool": MaxRoiPool, + # "RoiAlign": RoiAlign, + # "NonMaxSuppression": NonMaxSuppression, + # "GridSample": GridSample, + # "Upsample": Upsample, + # others + "DepthToSpace": DepthToSpace, + "SpaceToDepth": SpaceToDepth, + # Sequence operators + "SequenceConstruct": SequenceConstruct, + "SequenceEmpty": SequenceEmpty, + "SequenceErase": SequenceErase, + "SequenceInsert": SequenceInsert, + "SequenceLength": SequenceLength, + "ConcatFromSequence": ConcatFromSequence, + "SplitToSequence": SplitToSequence, + "SequenceAt": SequenceAt, } @@ -2269,6 +3016,14 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): "Where", "Cast", ] + return_tuple_ops = [ + "SequenceConstruct", + "SequenceEmpty", + "SequenceErase", + "SequenceInsert", + "ConcatFromSequence", + "SplitToSequence", + ] for i, inp in enumerate(inputs): if ( inp is not None @@ -2277,11 +3032,17 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): and op_name not in shape_compatible_ops ): raise ValueError(f"Node {node.name} cannot handle ShapeExpr inputs.") - op = self._convert_operator(op_name, inputs, attr, self.opset) - # 
Create struct information for the new operator. - op = self.bb.normalize(op) - - if not isinstance(op, relax.Tuple): + try: + op = self._convert_operator(op_name, inputs, attr, self.opset) + # Create struct information for the new operator. + op = self.bb.normalize(op) + except TVMError as err: + print(f"Error converting operator {op_name}, with inputs: {inputs}") + raise err + + if op_name in return_tuple_ops: + outputs_num = 1 + elif not isinstance(op, relax.Tuple): if isinstance(op.checked_type, tvm.ir.type.TupleType): # This is a var bound to a tuple. We need to unpack it and create # a new tuple. @@ -2299,7 +3060,6 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): ), "Missing outputs during conversion. Expected {} but Got {} in {}.".format( len(outputs), outputs_num, op_name ) - if outputs_num == 1: self._nodes[outputs[0]] = op else: @@ -2346,10 +3106,10 @@ def _parse_attr(self, attr_proto: onnx.onnx_ml_pb2.AttributeProto) -> Dict[str, def _convert_operator( self, op_name: str, - inputs: List[relax.Function], + inputs: List[relax.Expr], attrs: Dict, opset: int, - ) -> relax.Function: + ) -> relax.Expr: """Convert ONNX operator into a Relax operator. The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. @@ -2386,7 +3146,7 @@ def from_onnx( opset: int = None, keep_params_in_input: bool = False, sanitize_input_names: bool = True, -) -> Tuple[IRModule, Dict]: +) -> IRModule: """Convert a ONNX model into an equivalent Relax Function. ONNX graphs are represented as Python Protobuf objects. @@ -2413,8 +3173,6 @@ def from_onnx( ------- mod : tvm.IRModule The relax module for compilation - params : dict of str to tvm.nd.NDArray - The parameter dict to be used by relax """ # Error if the model version is below 1.1.0 if model.ir_version < 3: diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py index 4d106ad6d23c..0b86e19ce53f 100644 --- a/python/tvm/relax/op/set.py +++ b/python/tvm/relax/op/set.py @@ -77,7 +77,7 @@ def unique( return_inverse = PrimValue(return_inverse) if isinstance(return_counts, bool): return_counts = PrimValue(return_counts) - if axis and isinstance(axis, int): + if axis is not None and isinstance(axis, int): axis = PrimValue(axis) return _ffi_api.unique( # type: ignore x, sorted, return_index, return_inverse, return_counts, axis @@ -91,6 +91,7 @@ def numpy_unique( return_index: int, return_inverse: int, return_counts: int, + axis: Optional[int] = None, ) -> tvm.nd.array: """Returns the unique elements of the input tensor. @@ -103,8 +104,9 @@ def numpy_unique( raise NotImplementedError("missing support return_inverse or return_counts set to true") x_numpy = x.numpy() # TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci. 
- output_sorted_numpy, indices = np.unique(x_numpy, return_index=True) + output_sorted_numpy, indices = np.unique(x_numpy, return_index=True, axis=axis) + if sorted: return tvm.nd.array(output_sorted_numpy) - output_numpy = [x_numpy.flatten()[index] for index in builtins.sorted(indices, reverse=True)] + output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis) return tvm.nd.array(output_numpy) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 809d231fd30d..8317d4504e1e 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -171,21 +171,16 @@ def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr: "and thus cannot be legalized by TOPI" ) return call - if call.attrs.groups != 1: - logging.info( - "TOPI conv1d_transpose does not support groups other than 1, " - "and thus cannot be legalized by TOPI" - ) - return call return bb.call_te( - topi.nn.conv1d_transpose_ncw, + topi.nn.group_conv1d_transpose_ncw, call.args[0], call.args[1], stride=call.attrs.strides, padding=call.attrs.padding, out_dtype=call.struct_info.dtype, output_padding=call.attrs.output_padding, + groups=call.attrs.groups, primfunc_name_hint="conv1d_transpose", ) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 0e7cfbd7c093..2837ad2185e9 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -21,7 +21,7 @@ This file is a test script to test Relax ONNX frontend coverage. """ -from typing import Dict, Optional +from typing import Dict, List, Literal, Optional import numpy as np import onnx @@ -118,6 +118,7 @@ def check_correctness( tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) # Legalize any relax ops into tensorir. tvm_model = relax.transform.LegalizeOps()(tvm_model) + print(tvm_model) # Separate model from parameters. tvm_model, params = relax.frontend.detach_params(tvm_model) @@ -137,25 +138,31 @@ def check_correctness( vm.invoke_stateful("main") tvm_output = vm.get_outputs("main") # Wrap as a list if there is only one output. - if isinstance(tvm_output, tvm.nd.NDArray): + if len(ort_output) == 1: + # Do not check the output number for TVM + # As for sequence output, the TVM output is a Tuple + # while the ONNX output number is one, which is a list tvm_output = [tvm_output] - # If the output is a shape tuple, convert it to an ndarray for comparison. - if isinstance(tvm_output, tvm.runtime.ShapeTuple): - tvm_output = [tvm.nd.array([int(i) for i in tvm_output])] - tvm_num_outputs = len(tvm_output) - # Shape tuples need to be handled specially. 
- if isinstance(tvm_output, tvm.runtime.ShapeTuple): - tvm_num_outputs = 1 + def _check_output(tvm_out, ort_out): + if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)): + assert len(tvm_out) == len(ort_out), "Unequal number of outputs" + for tvm_out_i, ort_out_i in zip(tvm_out, ort_out): + _check_output(tvm_out_i, ort_out_i) + elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray): + tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol) + elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray): + shape_out = tvm.nd.array([int(i) for i in tvm_out]) + tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol) + else: + raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}") # Check that number of outputs match. - assert tvm_num_outputs == len(ort_output), "Unequal number of outputs" - + assert len(tvm_output) == len(ort_output), "Unequal number of outputs" for (tvm_out, ort_out) in zip(tvm_output, ort_output): # TODO Allow configurable tolerance. - # Sometimes None is used to indicate an unused output. if ort_out is not None: - tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol) + _check_output(tvm_out, ort_out) @pytest.mark.parametrize( @@ -187,35 +194,61 @@ def test_sanitize(input_names, expected_names): assert param.name_hint == expected_names[i] -def verify_unary(op_name, shape, attrs={}, domain=None, dtype=TensorProto.FLOAT): +def verify_unary( + op_name, + shape, + attrs={}, + domain=None, + input_dtype=TensorProto.FLOAT, + output_dtype=TensorProto.FLOAT, + opset=14, +): test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain) graph = helper.make_graph( [test_node], "elemwise_test", inputs=[ - helper.make_tensor_value_info("x", dtype, shape), + helper.make_tensor_value_info("x", input_dtype, shape), ], - outputs=[helper.make_tensor_value_info("y", dtype, shape)], + outputs=[helper.make_tensor_value_info("y", output_dtype, shape)], ) model = helper.make_model(graph, producer_name="elemwise_test") - check_correctness(model) + check_correctness(model, opset=opset) -def verify_binary(op_name, shape_a, shape_b, shape_c, attrs={}, domain=None): +def verify_binary( + op_name, shape_a, shape_b, shape_c, attrs={}, domain=None, dtype=TensorProto.FLOAT, opset=14 +): test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, domain=domain) graph = helper.make_graph( [test_node], "binary_test", inputs=[ - helper.make_tensor_value_info("a", TensorProto.FLOAT, shape_a), - helper.make_tensor_value_info("b", TensorProto.FLOAT, shape_b), + helper.make_tensor_value_info("a", dtype, shape_a), + helper.make_tensor_value_info("b", dtype, shape_b), ], - outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, shape_c)], + outputs=[helper.make_tensor_value_info("c", dtype, shape_c)], ) model = helper.make_model(graph, producer_name="binary_test") - check_correctness(model) + check_correctness(model, opset=opset) + + +def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32, opset=14): + a = make_constant_node("a", dtype, [], [4]) + b = make_constant_node("b", dtype, [], [8]) + test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, domain=domain) + graph = helper.make_graph( + [a, b, test_node], + "binary_test", + inputs=[], + outputs=[helper.make_tensor_value_info("c", dtype, ())], + ) + + model = helper.make_model(graph, producer_name="binary_test") + # NOTE: explicitly pass 
inputs to avoid numerical error + check_correctness(model, opset=opset) def verify_compare(op_name, shape, attrs={}, domain=None): @@ -289,16 +322,95 @@ def test_concat(): verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0}) -def test_add(): - verify_binary("Add", [1, 32], [1, 32], [1, 32]) +@pytest.mark.parametrize("op_name", ["Add", "Sub", "Mul", "Div", "Pow"]) +def test_binary(op_name: str): + verify_binary(op_name, [1, 32], [1, 32], [1, 32]) + verify_binary_scalar(op_name) + + +@pytest.mark.parametrize("num_inputs", [1, 2, 4]) +@pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"]) +def test_multi_input(op_name: str, num_inputs: int): + input_shape = [32, 32] + input_var = ["i" + str(i) for i in range(num_inputs)] + input_values = [ + helper.make_tensor_value_info(var, TensorProto.FLOAT, input_shape) for var in input_var + ] + test_node = helper.make_node(op_name, input_var, ["c"]) + graph = helper.make_graph( + [test_node], + "multi_input_test", + inputs=input_values, + outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, input_shape)], + ) + + model = helper.make_model(graph, producer_name="multi_input_test") + check_correctness(model) -def test_mul(): - verify_binary("Mul", [1, 32], [1, 32], [1, 32]) +@pytest.mark.parametrize("op_name", ["Less", "LessOrEqual", "Greater", "GreaterOrEqual"]) +def test_compare(op_name: str): + verify_compare(op_name, [1, 32]) -def test_sum(): - verify_binary("Sum", [1, 32], [1, 32], [1, 32]) +@pytest.mark.parametrize("op_name", ["And", "Or", "Xor"]) +def test_binary_bool(op_name: str): + verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.BOOL) + + +@pytest.mark.parametrize( + "op_name", + [ + "Sin", + "Cos", + "Tan", + "Sinh", + "Cosh", + "Tanh", + "Asin", + "Acos", + "Atan", + "Asinh", + "Acosh", + "Atanh", + "Neg", + "Abs", + "Log", + "Exp", + "Not", + "Reciprocal", + "Floor", + "Ceil", + "Round", + "IsInf", + "IsNaN", + "Sqrt", + "Relu", + "Elu", + "HardSwish", + "Sign", + "Softplus", + "Softsign", + "Erf", + "Sigmoid", + "Softmax", + "LogSoftmax", + "Identity", + ], +) +def test_unary(op_name: str): + input_dtype = TensorProto.FLOAT + if op_name in [ + "IsNaN", + "IsInf", + ]: + pytest.skip(f"Skipping test {op_name} because current LegalizeOps does not support it.") + elif op_name == "Not": + input_dtype = TensorProto.BOOL + output_dtype = TensorProto.BOOL + else: + output_dtype = TensorProto.FLOAT + verify_unary(op_name, [32, 32], input_dtype=input_dtype, output_dtype=output_dtype) @pytest.mark.parametrize("from_type", [TensorProto.INT32, TensorProto.FLOAT, TensorProto.FLOAT16]) @@ -350,6 +462,44 @@ def _verify_gather(data_shape, indices, out_shape, axis=0): _verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1) +@pytest.mark.parametrize("axis", [0, 1, 2]) +@pytest.mark.parametrize(("name", "opset"), [("Scatter", 10), ("ScatterElements", 11)]) +def test_scatter(axis: int, name: str, opset: int): + if axis != 1: + pytest.skip("The current topi impl is wrong, which only works for axis=1") + input_shape = [16, 16, 16] + indices_shape = [8, 8, 8] + updates_shape = [8, 8, 8] + output_shape = [16, 16, 16] + node = helper.make_node(name, ["data", "indices", "updates"], ["output"], axis=axis) + graph = helper.make_graph( + [node], + "scatter_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, input_shape), + helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape), + helper.make_tensor_value_info("updates", TensorProto.FLOAT, updates_shape), + ], + 
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="scatter_test") + indices = np.random.randint(0, 16, indices_shape) + check_correctness(model, inputs={"indices": indices}, opset=opset) + + +def test_size(): + test_node = helper.make_node("Size", ["x"], ["y"]) + graph = helper.make_graph( + [test_node], + "size_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 3, 3])], + outputs=[helper.make_tensor_value_info("y", TensorProto.INT64, [3])], + ) + + model = helper.make_model(graph, producer_name="size_test") + check_correctness(model) + + @pytest.mark.parametrize("alpha", [None, 0.25, 1.0]) @pytest.mark.parametrize("beta", [None, 0.35, 1.0]) @pytest.mark.parametrize("useC", [False, True]) @@ -408,18 +558,6 @@ def test_reshape(in_shape, shape, out_shape): check_correctness(model, inputs=input_values) -def test_div(): - verify_binary("Div", [32, 32], [32, 32], [32, 32]) - - -def test_sigmoid(): - verify_unary("Sigmoid", [32, 32]) - - -def test_softmax(): - verify_unary("Softmax", [32, 32, 32]) - - def test_transpose(): verify_unary("Transpose", [32, 32, 32], attrs={"perm": [1, 2, 0]}) @@ -567,28 +705,33 @@ def test_shape(): check_correctness(model) -def test_tanh(): - verify_unary("Tanh", [9, 8, 7, 6]) +@pytest.mark.parametrize("upper", [True, False]) +def test_trilu(upper: bool): + verify_unary("Trilu", [3, 5, 5], attrs={"upper": upper}) -def test_sqrt(): - verify_unary("Sqrt", [32, 32]) +def test_selu(): + verify_unary("Selu", [3, 32, 32]) + verify_unary("Selu", [3, 32, 32], attrs={"alpha": 0.25, "gamma": 0.3}) -def test_relu(): - verify_unary("Relu", [32, 32]) +@pytest.mark.skip(reason="opset 18 is not supported in CI") +def test_mish(): + verify_unary("Mish", [3, 32, 32], opset=18) -def test_tril(): - verify_unary("Trilu", [3, 5, 5], attrs={"upper": False}) +def test_prelu(): + verify_binary("PRelu", [3, 32, 32], [3, 32, 32], [3, 32, 32]) -def test_triu(): - verify_unary("Trilu", [3, 5, 5], attrs={"upper": True}) +def test_thresholded_relu(): + verify_unary("ThresholdedRelu", [3, 32, 32]) + verify_unary("ThresholdedRelu", [3, 32, 32], attrs={"alpha": -0.01}) -def test_elu(): - verify_unary("Elu", [32, 32]) +def test_leakyrelu(): + verify_unary("LeakyRelu", [32, 32]) + verify_unary("LeakyRelu", [32, 32], attrs={"alpha": 0.2}) def test_hardsigmoid(): @@ -597,30 +740,40 @@ def test_hardsigmoid(): verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta": 0.6}) -def test_hardswish(): - verify_unary("HardSwish", [32, 32]) - - -def test_sign(): - verify_unary("Sign", [32, 32]) - - -def test_not(): - verify_unary("Not", [32, 32], dtype=TensorProto.BOOL) +def test_shrink(): + verify_unary("Shrink", [32, 32]) + verify_unary("Shrink", [32, 32], attrs={"lambd": 0.2, "bias": 0.1}) -def test_conv(): - def _verify_conv(input_shape, weight_shape, output_shape): +@pytest.mark.parametrize("stride", [1, 2]) +@pytest.mark.parametrize("dilation", [1, 2]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("pad", [0, 2]) +def test_conv(stride: int, dilation: int, pad: int, bias: bool): + def _verify_conv(input_shape, weight_shape): + nd = len(weight_shape) - 2 + output_shape = [input_shape[0], weight_shape[0]] + [ + (input_shape[i] + 2 * pad - dilation * (weight_shape[i] - 1) - 1) // stride + 1 + for i in range(2, len(input_shape)) + ] bias_shape = [output_shape[1]] - conv_node = helper.make_node("Conv", ["x", "w", "b"], ["y"]) + conv_node = helper.make_node( + 
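+            # The expected spatial size above follows the standard Conv output formula,
+            # e.g. for input 32, kernel 3, pad 2, dilation 2, stride 2:
+            #   (32 + 2*2 - 2*(3 - 1) - 1) // 2 + 1 = 31 // 2 + 1 = 16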
"Conv", + inputs=["x", "w"] + (["b"] if bias else []), + outputs=["y"], + strides=[stride] * nd, + dilations=[dilation] * nd, + pads=[pad] * nd * 2, + group=input_shape[1] // weight_shape[1], + ) graph = helper.make_graph( [conv_node], "conv_test", inputs=[ helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape), helper.make_tensor_value_info("w", TensorProto.FLOAT, weight_shape), - helper.make_tensor_value_info("b", TensorProto.FLOAT, bias_shape), - ], + ] + + ([helper.make_tensor_value_info("b", TensorProto.FLOAT, bias_shape)] if bias else []), outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)], ) @@ -628,20 +781,61 @@ def _verify_conv(input_shape, weight_shape, output_shape): check_correctness(model, atol=1e-4) # Conv1D - _verify_conv([3, 12, 32], [4, 12, 3], [3, 4, 30]) + _verify_conv([3, 4, 32], [4, 4, 3]) + _verify_conv([3, 4, 32], [2, 4, 3]) # group=2 # Conv2D - _verify_conv([3, 12, 32, 32], [4, 12, 3, 3], [3, 4, 30, 30]) + _verify_conv([3, 4, 32, 32], [4, 4, 3, 3]) + _verify_conv([3, 4, 32, 32], [2, 4, 3, 3]) # group=2 # Conv3D - _verify_conv([3, 12, 32, 32, 32], [4, 12, 3, 3, 3], [3, 4, 30, 30, 30]) + _verify_conv([3, 4, 32, 32, 32], [4, 4, 3, 3, 3]) + _verify_conv([3, 4, 32, 32, 32], [2, 4, 3, 3, 3]) # group=2 + + +@pytest.mark.parametrize("stride", [1, 2]) +@pytest.mark.parametrize("dilation", [1]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("pad", [0, 2]) +def test_conv_transpose(stride: int, dilation: int, pad: int, bias: bool): + def _verify_conv_transpose(input_shape, weight_shape): + nd = len(weight_shape) - 2 + output_shape = [input_shape[0], weight_shape[0]] + [ + (input_shape[i] - 1) * stride - 2 * pad + dilation * (weight_shape[i] - 1) + 1 + for i in range(2, len(input_shape)) + ] + bias_shape = [output_shape[1]] + conv_node = helper.make_node( + "ConvTranspose", + inputs=["x", "w"] + (["b"] if bias else []), + outputs=["y"], + strides=[stride] * nd, + dilations=[dilation] * nd, + pads=[pad] * nd * 2, + group=input_shape[1] // weight_shape[1], + ) + graph = helper.make_graph( + [conv_node], + "conv_transpose_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape), + helper.make_tensor_value_info("w", TensorProto.FLOAT, weight_shape), + ] + + ([helper.make_tensor_value_info("b", TensorProto.FLOAT, bias_shape)] if bias else []), + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="conv_transpose_test") + check_correctness(model, atol=1e-4) -def test_pow(): - verify_binary("Pow", [32, 32], [32, 32], [32, 32]) + # ConvTranspose1D + _verify_conv_transpose([3, 4, 32], [4, 4, 3]) + _verify_conv_transpose([3, 4, 32], [4, 2, 3]) # group=2 + # ConvTranspose2D + _verify_conv_transpose([3, 4, 32, 32], [4, 4, 3, 3]) + _verify_conv_transpose([3, 4, 32, 32], [4, 2, 3, 3]) # group=2 -def test_erf(): - verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT) - verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT16) +def test_pow(): + verify_binary("Pow", [32, 32], [32, 32], [32, 32]) @pytest.mark.parametrize("reverse", [False]) @@ -712,46 +906,6 @@ def test_const(): check_correctness(model) -def test_sub(): - verify_binary("Sub", [32, 16], [32, 16], [32, 16]) - - -def test_min(): - verify_binary("Min", [32, 16], [32, 16], [32, 16]) - - -def test_max(): - verify_binary("Max", [32, 16], [32, 16], [32, 16]) - - -def test_sin(): - verify_unary("Sin", [32, 16]) - - -def test_cos(): - verify_unary("Cos", [32, 16]) 
- - -def test_identity(): - verify_unary("Identity", [32, 16]) - - -def test_neg(): - verify_unary("Neg", [32, 16]) - - -def test_abs(): - verify_unary("Abs", [32, 16]) - - -def test_log(): - verify_unary("Log", [32, 16]) - - -def test_exp(): - verify_unary("Exp", [32, 16]) - - def test_instance_norm(): verify_ternary( "InstanceNormalization", [1, 3, 32, 32], [3], [3], [1, 3, 32, 32], attrs={"epsilon": 1e-12} @@ -761,6 +915,11 @@ def test_instance_norm(): ) +def test_mean_variance_norm(): + verify_unary("MeanVarianceNormalization", [1, 3, 32, 32]) + verify_unary("MeanVarianceNormalization", [1, 3, 32, 32], attrs={"axes": (1, 2, 3)}) + + def test_layer_norm(): layer_norm_node = helper.make_node("LayerNormalization", ["a", "b", "c"], ["d"], epsilon=1e-12) @@ -1075,9 +1234,36 @@ def verify_arg_min_max(input_dim, in_dtype, op_name="ArgMax", axis=None, keepdim verify_arg_min_max([3, 4, 4], in_dtype, "ArgMin", axis, keepdims) +@pytest.mark.parametrize("axis", [-1, 0, 1]) +@pytest.mark.parametrize("largest", [True, False]) +def test_topk(axis: int, largest: int): + in_shape = [32, 32, 32] + k_value = 4 + out_shape = in_shape + out_shape[axis] = k_value + k = make_constant_node("k", TensorProto.INT64, [1], [k_value]) + node = onnx.helper.make_node( + "TopK", + inputs=["data", "k"], + outputs=["values", "indices"], + axis=axis, + largest=largest, + ) + graph = helper.make_graph( + [k, node], + "topk_test", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, in_shape)], + outputs=[ + helper.make_tensor_value_info("values", TensorProto.FLOAT, out_shape), + helper.make_tensor_value_info("indices", TensorProto.INT64, out_shape), + ], + ) + model = helper.make_model(graph, producer_name="topk_test") + + check_correctness(model) + + @pytest.mark.parametrize("dynamic", [False, True]) -# TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed. 
-@pytest.mark.skip("Produces ill-formed IR") def test_expand(dynamic): if dynamic: # TODO: Support dynamic shape for Expand @@ -1586,14 +1772,6 @@ def test_range(): check_correctness(model) -def test_less(): - verify_compare("Less", [32, 32]) - - -def test_less_equal(): - verify_compare("LessOrEqual", [32, 32]) - - def test_batch_norm(): batch_norm_node = helper.make_node( "BatchNormalization", ["x", "s", "bias", "mean", "var"], ["y"], epsilon=1e-2 @@ -1811,17 +1989,58 @@ def test_global_average_pool(): verify_unary("GlobalAveragePool", [1, 3, 32, 32, 32]) +def test_global_max_pool(): + verify_unary("GlobalMaxPool", [1, 3, 32]) + verify_unary("GlobalMaxPool", [1, 3, 32, 32]) + verify_unary("GlobalMaxPool", [1, 3, 32, 32, 32]) + + +@pytest.mark.parametrize("p", [1, 2, 3]) +def test_global_lp_pool(p: int): + verify_unary("GlobalLpPool", [1, 3, 32], attrs={"p": p}) + verify_unary("GlobalLpPool", [1, 3, 32, 32], attrs={"p": p}) + verify_unary("GlobalLpPool", [1, 3, 32, 32, 32], attrs={"p": p}) + + +@pytest.mark.parametrize("kernel_shape", [[2, 2], [3, 3]]) +@pytest.mark.parametrize("pads", [None, [1, 1, 1, 1]]) +@pytest.mark.parametrize("strides", [None, [2, 2]]) +def test_maxunpool(kernel_shape, pads, strides): + input_shape = [16, 3, 16, 16] + input_names = ["X", "I"] + input_info = [ + helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape), + helper.make_tensor_value_info("I", TensorProto.INT64, input_shape), + ] + + attrs = {"kernel_shape": kernel_shape} + if pads is not None: + attrs["pads"] = pads + if strides is not None: + attrs["strides"] = strides + + node = helper.make_node("MaxUnpool", inputs=input_names, outputs=["y"], **attrs) + + graph = helper.make_graph( + [node], + "maxunpool_test", + inputs=input_info, + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, None)], + ) + + max_random = int(np.prod(np.array(kernel_shape))) + indices = np.random.randint(0, max_random, size=input_shape) + + model = helper.make_model(graph, producer_name="maxunpool_test") + check_correctness(model, inputs={"I": indices}) + + def test_flatten(): verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 0}) verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": -1}) verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 2}) -def test_greater(): - verify_compare("Greater", [32, 32]) - verify_compare("Greater", [64, 16]) - - def test_onehot(): one_hot_node = helper.make_node("OneHot", ["indices", "depth", "values"], ["y"], axis=1) graph = helper.make_graph( @@ -1844,8 +2063,189 @@ def test_onehot(): check_correctness(model, inputs=values) -def test_reciprocal(): - verify_unary("Reciprocal", [3, 32, 32]) +@pytest.mark.parametrize("axis", [None, 0, 1, -1]) +@pytest.mark.parametrize("sorted", [0, 1]) +def test_unique(axis: Optional[int], sorted: int): + input_shape = [32, 32] + if axis is None: + output_shape = [-1] + else: + output_shape = [32, 32] + output_shape[axis] = -1 + unique_node = helper.make_node("Unique", ["x"], ["y"], axis=axis, sorted=sorted) + graph = helper.make_graph( + [unique_node], + "unique_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape)], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="unique_test") + check_correctness(model) + + +@pytest.mark.parametrize("mode", ["DCR", "CRD"]) +def test_depth_to_space(mode: Literal["DCR", "CRD"]): + in_shape = [1, 8, 2, 3] + out_shape = [1, 2, 4, 6] + blocksize = 2 + node = onnx.helper.make_node( + 
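+        # With blocksize 2 the channel dim shrinks by 2*2 and both spatial dims double:
+        # [1, 8, 2, 3] -> [1, 2, 4, 6]. DCR and CRD differ only in the order in which the
+        # depth dimension is unpacked into spatial blocks (depth-column-row vs column-row-depth).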
"DepthToSpace", inputs=["x"], outputs=["y"], blocksize=blocksize, mode=mode + ) + graph = helper.make_graph( + [node], + "depth_to_space_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, in_shape)], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_shape)], + ) + model = helper.make_model(graph, producer_name="depth_to_space_test") + + check_correctness(model) + + +def test_space_to_depth(): + in_shape = [1, 2, 4, 6] + out_shape = [1, 8, 2, 3] + blocksize = 2 + node = onnx.helper.make_node("SpaceToDepth", inputs=["x"], outputs=["y"], blocksize=blocksize) + graph = helper.make_graph( + [node], + "space_to_depth_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, in_shape)], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_shape)], + ) + model = helper.make_model(graph, producer_name="space_to_depth_test") + + check_correctness(model) + + +def construct_sequence(input_shape: List[int], num_tensors: int, name: str = "sequence"): + inputs = [f"data{i}" for i in range(num_tensors)] + sequence_construct_node = helper.make_node("SequenceConstruct", inputs, [name]) + graph_inputs = [ + helper.make_tensor_value_info(f"data{i}", TensorProto.FLOAT, input_shape) + for i in range(num_tensors) + ] + return sequence_construct_node, graph_inputs + + +def make_constant_node(name: str, data_type: int, dims: List[int], vals: List[int]): + return helper.make_node( + "Constant", + inputs=[], + outputs=[name], + value=helper.make_tensor(name=name, data_type=data_type, dims=dims, vals=vals), + ) + + +def test_sequence_construct(): + node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=2) + graph = helper.make_graph( + [node], + "test_sequence_construct", + inputs=graph_inputs, + outputs=[helper.make_tensor_sequence_value_info("sequence", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_sequence_construct") + check_correctness(model) + + +def test_sequence_empty(): + sequence_empty_node = helper.make_node("SequenceEmpty", [], ["sequence"]) + graph = helper.make_graph( + [sequence_empty_node], + "test_sequence_empty", + inputs=[], + outputs=[helper.make_tensor_sequence_value_info("sequence", TensorProto.FLOAT, [])], + ) + model = helper.make_model(graph, producer_name="test_sequence_empty") + check_correctness(model) + + +@pytest.mark.parametrize("explicit_position", [True, False]) +def test_sequence_erase(explicit_position: bool): + seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=4) + index = make_constant_node("index", TensorProto.INT64, (), [1]) + node_input = ["sequence", "index"] if explicit_position else ["sequence"] + sequence_erase_node = helper.make_node("SequenceErase", node_input, ["output"]) + graph = helper.make_graph( + [index, seq_node, sequence_erase_node], + "test_sequence_erase", + inputs=graph_inputs, + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_sequence_erase") + check_correctness(model) + + +@pytest.mark.parametrize("explicit_position", [True, False]) +def test_sequence_insert(explicit_position: bool): + seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=4) + index = make_constant_node("index", TensorProto.INT64, (), [0]) + node_input = ["sequence", "value", "index"] if explicit_position else ["sequence", "value"] + sequence_insert_node = helper.make_node("SequenceInsert", node_input, ["output"]) 
+ graph = helper.make_graph( + [index, seq_node, sequence_insert_node], + "test_sequence_insert", + inputs=[*graph_inputs, helper.make_tensor_value_info("value", TensorProto.FLOAT, [32, 32])], + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_sequence_insert") + check_correctness(model) + + +@pytest.mark.parametrize("new_axis", [0, 1]) +def test_concat_from_sequence(new_axis: Literal[0, 1]): + if new_axis == 1: + pytest.skip("ConcatFromSequence with new_axis=1 is not supported yet") + seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=2) + concat_from_sequence_node = helper.make_node( + "ConcatFromSequence", ["sequence"], ["output"], axis=1 + ) + graph = helper.make_graph( + [seq_node, concat_from_sequence_node], + "test_concat_from_sequence", + inputs=graph_inputs, + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [64, 32])], + ) + model = helper.make_model(graph, producer_name="test_concat_from_sequence") + check_correctness(model) + + +@pytest.mark.parametrize("split", [2, [16, 48]]) +def test_split_to_sequence(split): + split_to_sequence_node = helper.make_node( + "SplitToSequence", + ["data", "split"], + ["output"], + axis=0, + ) + split_shape = [len(split)] if isinstance(split, list) else () + split_node = make_constant_node( + "split", TensorProto.INT64, split_shape, [split] if isinstance(split, int) else split + ) + graph = helper.make_graph( + [split_node, split_to_sequence_node], + "test_split_to_sequence", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [64, 32])], + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_split_to_sequence") + check_correctness(model) + + +def test_sequence_at(): + seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=4) + index = make_constant_node("index", TensorProto.INT64, (), [1]) + node_input = ["sequence", "index"] + sequence_at_node = helper.make_node("SequenceAt", node_input, ["output"]) + graph = helper.make_graph( + [index, seq_node, sequence_at_node], + "test_sequence_at", + inputs=graph_inputs, + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_sequence_at") + check_correctness(model) def test_symbolic_shape_deduction(): diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index fcb8727d8508..a80b988d06c4 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -60,7 +60,7 @@ def test_unique(exec_mode): result, result_sorted = run_cpu(InputModule, "foo", data, exec_mode=exec_mode) expected_output_sorted, indices = np.unique(data_numpy, return_index=True) - expected_output = [data_numpy.flatten()[index] for index in sorted(indices, reverse=True)] + expected_output = [data_numpy.flatten()[index] for index in sorted(indices)] np.testing.assert_array_equal(expected_output_sorted, result_sorted.numpy()) np.testing.assert_array_equal(expected_output, result.numpy()) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index d03d48968d90..12436cf8023f 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -204,6 +204,53 @@ def 
conv1d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv1 tvm.ir.assert_structural_equal(mod, Expected) +def test_conv1d_transpose(): + # fmt: off + @I.ir_module + class Conv1dTranspose: + @R.function + def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((128, 16, 3), "float32")): + gv = R.nn.conv1d_transpose(x, w, strides=2, padding=1, dilation=1, output_padding=1, groups=8) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def conv1d_transpose(x: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), w: T.Buffer((T.int64(128), T.int64(16), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + data_dilate = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(55))) + data_pad = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(58))) + kernel = T.alloc_buffer((T.int64(16), T.int64(128), T.int64(3))) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(55)): + with T.block("data_dilate"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + data_dilate[v_i0, v_i1, v_i2] = T.if_then_else(v_i2 % T.int64(2) == T.int64(0), x[v_i0, v_i1, v_i2 // T.int64(2)], T.float32(0.0)) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(58)): + with T.block("data_pad"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + data_pad[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(56), data_dilate[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0.0)) + for o, i, w_1 in T.grid(T.int64(16), T.int64(128), T.int64(3)): + with T.block("kernel"): + v_o, v_i, v_w = T.axis.remap("SSS", [o, i, w_1]) + kernel[v_o, v_i, v_w] = w[v_i, v_o, T.int64(2) - v_w] + for b, c, w_1, dc, dw in T.grid(T.int64(2), T.int64(128), T.int64(56), T.int64(16), T.int64(3)): + with T.block("compute"): + v_b, v_c, v_w, v_dc, v_dw = T.axis.remap("SSSRR", [b, c, w_1, dc, dw]) + with T.init(): + compute[v_b, v_c, v_w] = T.float32(0.0) + compute[v_b, v_c, v_w] = compute[v_b, v_c, v_w] + data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_w + v_dw] * kernel[v_c % T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dw] + + @R.function + def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((128, 16, 3), dtype="float32")) -> R.Tensor((2, 128, 56), dtype="float32"): + cls = Expected + gv = R.call_tir(cls.conv1d_transpose, (x, w), out_sinfo=R.Tensor((2, 128, 56), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Conv1dTranspose) + tvm.ir.assert_structural_equal(mod, Expected) + + def test_conv2d(): # fmt: off @tvm.script.ir_module