From 6119889ffc7309f499d62021805c083edfe35033 Mon Sep 17 00:00:00 2001 From: Brunno Goldstein Date: Wed, 23 Mar 2022 16:10:40 -0300 Subject: [PATCH 1/6] ONNX Opset 14 - HardSwish Added hardswish support to TVM CI and fixed unit test. - Add class HardSwish and added its reference to convert_map in onnx.py; - Removed test_hardswish entry from test_forward.py; --- python/tvm/relay/frontend/onnx.py | 13 +++++++++++++ tests/python/frontend/onnx/test_forward.py | 1 - 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a751f23fe732..8565423d0ce9 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1841,6 +1841,18 @@ def _impl_v1(cls, inputs, attr, params): return AttrCvt("clip")([transformX], attr) +class HardSwish(OnnxOpConverter): + """Operator converter for HardSwish.""" + + @classmethod + def _impl_v14(cls, inputs, attr, params): + alpha = attr.get("alpha", 1/6) + beta = attr.get("beta", 0.5) + transformX = (inputs[0] * _expr.const(alpha) + _expr.const(beta)) + attr = {"a_min": 0, "a_max": 1} + return inputs[0] * AttrCvt("clip")([transformX], attr) + + class Reduce(OnnxOpConverter): """Operator converter for reduce ops.""" @@ -4674,6 +4686,7 @@ def _get_convert_map(opset): "PRelu": Prelu.get_converter(opset), "Sigmoid": Renamer("sigmoid"), "HardSigmoid": HardSigmoid.get_converter(opset), + "HardSwish": HardSwish.get_converter(opset), "Max": Maximum.get_converter(opset), "Min": Minimum.get_converter(opset), "Sum": Sum.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a4631e762f6f..22ff887e8d9e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5061,7 +5061,6 @@ def verify_eyelike(indata): "test_dropout_default_mask_ratio", "test_dropout_default_ratio", "test_gru_batchwise", - "test_hardswish", "test_identity_sequence", "test_if_seq", "test_loop11", From 0adfb266dec3bc0c0371dce60925144d213e72b9 Mon Sep 17 00:00:00 2001 From: Brunno Goldstein Date: Wed, 23 Mar 2022 16:38:17 -0300 Subject: [PATCH 2/6] ONNX Opset 14 Support - HardSwish Fixing onnx.py format. --- python/tvm/relay/frontend/onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 8565423d0ce9..1622abb99f82 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1846,9 +1846,9 @@ class HardSwish(OnnxOpConverter): @classmethod def _impl_v14(cls, inputs, attr, params): - alpha = attr.get("alpha", 1/6) + alpha = attr.get("alpha", 1 / 6) beta = attr.get("beta", 0.5) - transformX = (inputs[0] * _expr.const(alpha) + _expr.const(beta)) + transformX = inputs[0] * _expr.const(alpha) + _expr.const(beta) attr = {"a_min": 0, "a_max": 1} return inputs[0] * AttrCvt("clip")([transformX], attr) From 1a75e904070896692df28a272c5e915c4e8104cf Mon Sep 17 00:00:00 2001 From: Brunno Goldstein Date: Thu, 24 Mar 2022 07:51:45 -0300 Subject: [PATCH 3/6] jostle ci From 166639d188708f066bb1a0d69e05c932b2cf5626 Mon Sep 17 00:00:00 2001 From: Brunno Goldstein Date: Fri, 1 Apr 2022 14:47:38 -0300 Subject: [PATCH 4/6] [Relay/TOPI][ONNX/TFLite] Refactor MATRIX_SET_DIAG Operator for Relay/TOPI to support ONNX Trilu operator This commit is based on PR #9329 proposed by @shengxinhu. 
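For reference, a minimal NumPy sketch of the HardSwish lowering introduced in patches 1/6 and 2/6 earlier in this series. The helper name is illustrative only and is not part of the patch; the defaults alpha=1/6 and beta=0.5 are the ONNX ones used by the converter.

```python
import numpy as np

def hardswish_reference(x, alpha=1.0 / 6, beta=0.5):
    # Mirrors the converter: y = x * clip(alpha * x + beta, 0, 1)
    return x * np.clip(alpha * x + beta, 0.0, 1.0)

x = np.linspace(-4, 4, 9, dtype="float32")
print(hardswish_reference(x))
```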
Refactor MATRIX_SET_DIAG operator in Relay/TOPI to support ONNX Trilu operator; + Fixed issues related to shape transformation of inputs in tflite and onnx frontend ops. --- include/tvm/relay/attrs/transform.h | 4 -- include/tvm/topi/transform.h | 22 +++++------ python/tvm/relay/frontend/onnx.py | 32 ++++++++++++++++ python/tvm/relay/frontend/tflite.py | 12 ++++-- python/tvm/relay/op/transform.py | 6 +++ src/relay/op/tensor/transform.cc | 38 +++++++------------ src/topi/transform.cc | 5 +-- tests/python/frontend/onnx/test_forward.py | 18 --------- tests/python/relay/test_op_level10.py | 8 +++- .../python/topi/python/test_topi_transform.py | 30 +++++++++++---- 10 files changed, 102 insertions(+), 73 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index d16471b108ca..eb1d79d0e0a6 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -479,14 +479,10 @@ struct OneHotAttrs : public tvm::AttrsNode { /*! \brief Attributes used in matrix_set_diag operator */ struct MatrixSetDiagAttrs : public tvm::AttrsNode { - int k1; - int k2; bool super_diag_right_align; bool sub_diag_right_align; TVM_DECLARE_ATTRS(MatrixSetDiagAttrs, "relay.attrs.MatrixSetDiagAttrs") { - TVM_ATTR_FIELD(k1).set_default(0).describe("Lower limit (included) of the range of diagonals."); - TVM_ATTR_FIELD(k2).set_default(0).describe("Upper limit (included) of the range of diagonals."); TVM_ATTR_FIELD(super_diag_right_align) .set_default(true) .describe("Bool, true iff super-diagonal is right aligned (left-padded)."); diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index ef36c015957a..9c69b4500467 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1851,14 +1851,13 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Arrayshape.size() - 1; - bool only_one_diagonal = k1 == k2; - return compute( input->shape, [&](const Array& iter_vars) { @@ -1868,12 +1867,10 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k for (size_t i = 0; i < ndim - 1; i++) { diagonal_indices.push_back(iter_vars[i]); } - if (only_one_diagonal) { - k = k1; - } else { + auto multi_diagonals = [&]() { // Determining which diagonal/sub-diagonal/super-diagonal it is k = iter_vars[ndim] - iter_vars[ndim - 1]; - diagonal_indices.push_back(k2 - k); + diagonal_indices.push_back(k2(0) - k); // Calculating the offset in diagonal tensor for this diagonal auto get_offset = [&](PrimExpr M, PrimExpr N) { @@ -1886,13 +1883,16 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k : 0, sub_diag_right_align ? 
get_offset(input->shape[ndim], input->shape[ndim - 1] + k) : 0); - } + return k; + }; + auto get_k = [&]() { return if_then_else(k1(0) == k2(0), k1(0), multi_diagonals()); }; + k = get_k(); diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) + offset); return diagonal(diagonal_indices); }; - return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1, - if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2, + return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1(0), + if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2(0), get_diag(), input(iter_vars)), input(iter_vars)); }, diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 04fb17abbb19..ed7477f0ac28 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4637,6 +4637,37 @@ def _impl_v1(cls, inputs, attr, params): return _expr.TupleWrapper(_expr.Tuple(result), len(result)) +class Trilu(OnnxOpConverter): + """Operator converter for Trilu""" + + @classmethod + def _impl_v14(cls, inputs, attr, params): + upper = attr.get("upper", 1) + input_shape = shape_of(inputs[0]) + input_dims = infer_shape(input_shape)[0] + data_type = infer_type(inputs[0]).checked_type.dtype + k_tensor = relay.const(np.asarray(0), dtype=np.int64) + if len(inputs) == 2: + k_tensor = inputs[1] + + diag_input = relay.zeros(fold_constant(input_shape), dtype=data_type) + k1, k2 = None, None + if upper == 0: + k1 = relay.add(k_tensor, relay.const(1, dtype="int64")) + k1 = relay.expand_dims(k1, axis=0) + k2 = relay.take(input_shape, relay.const(input_dims - 1, dtype="int32")) + k2 = relay.expand_dims(k2, axis=0) + else: + k1 = relay.take(input_shape, relay.const(input_dims - 2, dtype="int32")) + k1 = relay.multiply(k1, relay.const(-1, dtype="int64")) + k1 = relay.subtract(k1, relay.const(1, dtype="int64")) + k1 = relay.expand_dims(k1, axis=0) + k2 = relay.subtract(k_tensor, relay.const(1, dtype="int64")) + k2 = relay.expand_dims(k2, axis=0) + + return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2)) + + # compatible operators that do NOT require any conversion. 
_identity_list = [] @@ -4810,6 +4841,7 @@ def _get_convert_map(opset): "CumSum": CumSum.get_converter(opset), "Unique": Unique.get_converter(opset), "Einsum": Einsum.get_converter(opset), + "Trilu": Trilu.get_converter(opset), # defs/control_flow "Loop": Loop.get_converter(opset), "If": If.get_converter(opset), diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d430eaccbdc3..dfeb8bdb0966 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3281,6 +3281,11 @@ def convert_matrix_set_diag(self, op): input_expr = self.get_tensor_expr(input_tensors[0]) diagonal_expr = self.get_tensor_expr(input_tensors[1]) + diag_shape = to_int_list(self.get_tensor_shape(input_tensors[1])) + input_shape = to_int_list(self.get_tensor_shape(input_tensors[0])) + if len(diag_shape) == len(input_shape) - 1: + diag_shape = np.insert(diag_shape, len(diag_shape) - 1, 1) + diagonal_expr = _op.reshape(diagonal_expr, diag_shape) out = _op.matrix_set_diag(input_expr, diagonal_expr) return out @@ -3300,14 +3305,15 @@ def convert_matrix_diag(self, op): ), "TFLite MATRIX_DIAG requires diagonal and output tensors' \ scale and zero points to be equal" + # Tflite's output tensor for matrix_diag has rank k+1 shape = to_int_list(self.get_tensor_shape(diagonal)) - shape = np.append(shape, shape[-1]) + # Diagonal's tensor has rank k. Therefore we remove the last dimension [:-1]. + diag_shape = np.insert(shape[:-1], len(shape[:-1]) - 1, 1).astype(np.int32) dtype = self.get_tensor_type_str(diagonal.tensor.Type()) - input_expr = _op.zeros(tuple(shape), dtype) diagonal_expr = self.get_tensor_expr(diagonal) - out = _op.matrix_set_diag(input_expr, diagonal_expr) + out = _op.matrix_set_diag(input_expr, _op.reshape(diagonal_expr, diag_shape)) return out def convert_densify(self, op): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 27dfefbb7890..f782451b277c 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -18,6 +18,7 @@ # pylint: disable=import-outside-toplevel """Transform operators.""" +import numpy as np from ...tir import expr as _expr from ..expr import Constant, Expr, Tuple, TupleWrapper, const from . 
import _make @@ -1409,6 +1410,11 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"): k_one = k k_two = k + if not isinstance(k_one, Expr): + k_one = const(np.asarray([k_one], dtype=np.int64)) + if not isinstance(k_two, Expr): + k_two = const(np.asarray([k_two], dtype=np.int64)) + super_diag_right_align = align[:5] == "RIGHT" sub_diag_right_align = align[-5:] == "RIGHT" diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index d225d93fe394..3df6f35d08d1 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3812,7 +3812,7 @@ TVM_REGISTER_NODE_TYPE(MatrixSetDiagAttrs); bool MatrixSetDiagRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [input, diagonal, result] - ICHECK_EQ(types.size(), 3); + ICHECK_EQ(types.size(), 5); const auto* input = types[0].as(); ICHECK(input); @@ -3820,30 +3820,18 @@ bool MatrixSetDiagRel(const Array& types, int num_inputs, const Attrs& att const auto* diagonal = types[1].as(); ICHECK(diagonal); - const auto param = attrs.as(); - ICHECK_GE(param->k2, param->k1); - - int d_ndims = diagonal->shape.size(); - int i_ndims = input->shape.size(); + const auto* k1 = types[2].as(); + ICHECK(k1); - reporter->Assert(input->shape[i_ndims - 2] > -param->k1); - reporter->Assert(input->shape[i_ndims - 1] > param->k2); + const auto* k2 = types[3].as(); + ICHECK(k2); + int d_ndims = diagonal->shape.size(); for (int i = 0; i < d_ndims - 2; i++) { reporter->AssertEQ(input->shape[i], diagonal->shape[i]); } - if (param->k1 != param->k2) { - reporter->AssertEQ(diagonal->shape[d_ndims - 2], param->k2 - param->k1 + 1); - } else if (d_ndims >= 2) { - reporter->AssertEQ(input->shape[d_ndims - 2], diagonal->shape[d_ndims - 2]); - } - auto max_diag_len = if_then_else(input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0) <= - input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0), - input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0), - input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0)); - reporter->AssertEQ(diagonal->shape[d_ndims - 1], max_diag_len); - reporter->Assign(types[2], TensorType(input->shape, input->dtype)); + reporter->Assign(types[4], TensorType(input->shape, input->dtype)); return true; } @@ -3851,20 +3839,18 @@ Array MatrixSetDiagCompute(const Attrs& attrs, const Array(); ICHECK(param != nullptr); - return Array{topi::matrix_set_diag(inputs[0], inputs[1], param->k1, param->k2, + return Array{topi::matrix_set_diag(inputs[0], inputs[1], inputs[2], inputs[3], param->super_diag_right_align, param->sub_diag_right_align)}; } -Expr MakeMatrixSetDiag(Expr input, Expr diagonal, int k1, int k2, bool super_diag_right_align, +Expr MakeMatrixSetDiag(Expr input, Expr diagonal, Expr k1, Expr k2, bool super_diag_right_align, bool sub_diag_right_align) { auto attrs = make_object(); - attrs->k1 = k1; - attrs->k2 = k2; attrs->super_diag_right_align = super_diag_right_align; attrs->sub_diag_right_align = sub_diag_right_align; static const Op& op = Op::Get("matrix_set_diag"); - return Call(op, {input, diagonal}, Attrs(attrs), {}); + return Call(op, {input, diagonal, k1, k2}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag); @@ -3880,9 +3866,11 @@ RELAY_REGISTER_OP("matrix_set_diag") **sub_diag_right_align** Bool, true iff sub-diagonal is right aligned (left-padded). 
)code" TVM_ADD_FILELINE) .set_attrs_type() - .set_num_inputs(2) + .set_num_inputs(4) .add_argument("input", "Tensor", "Input Tensor.") .add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.") + .add_argument("k1", "Tensor", "ILower limit (included) of the range of diagonals.") + .add_argument("k2", "Tensor", "Upper limit (included) of the range of diagonals.") .set_support_level(10) .add_type_rel("MatrixSetDiag", MatrixSetDiagRel) .set_attr("FTVMCompute", MatrixSetDiagCompute) diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 56e799f52563..d6dc42237bd3 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -213,11 +213,10 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValue* rv) { - int k1 = args[2]; - int k2 = args[3]; bool super_diag_right_align = args[4]; bool sub_diag_right_align = args[5]; - *rv = matrix_set_diag(args[0], args[1], k1, k2, super_diag_right_align, sub_diag_right_align); + *rv = matrix_set_diag(args[0], args[1], args[2], args[3], super_diag_right_align, + sub_diag_right_align); }); TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 8805d4d79c27..0e72811f53f6 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5114,24 +5114,6 @@ def verify_eyelike(indata): "test_training_dropout_mask", "test_training_dropout_zero_ratio", "test_training_dropout_zero_ratio_mask", - "test_tril", - "test_tril_pos", - "test_tril_square", - "test_tril_square_neg", - "test_tril_neg", - "test_tril_one_row_neg", - "test_tril_out_neg", - "test_tril_out_pos", - "test_tril_zero", - "test_triu", - "test_triu_one_row", - "test_triu_out_neg_out", - "test_triu_out_pos", - "test_triu_neg", - "test_triu_pos", - "test_triu_square", - "test_triu_square_neg", - "test_triu_zero", # These unsqueeze tests work, but take 2+ hrs to run "test_unsqueeze_three_axes", "test_unsqueeze_two_axes", diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 0486ef40017b..41b164423c6e 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -596,7 +596,13 @@ def test_matrix_set_diag(): def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): input = relay.var("input", relay.TensorType(input_shape, dtype)) diagonal = relay.var("diagonal", relay.TensorType(diagonal_shape, dtype)) - out = relay.matrix_set_diag(input, diagonal, k, align) + out = None + if len(diagonal_shape) == len(input_shape) - 1: + new_shape = list(diagonal_shape) + new_shape.insert(-1, 1) + out = relay.matrix_set_diag(input, relay.reshape(diagonal, new_shape), k, align) + else: + out = relay.matrix_set_diag(input, diagonal, k, align) in_type = run_infer_type(input) out_type = run_infer_type(out) diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 730d22cba16a..64ca1488c188 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -752,21 +752,38 @@ def check_device(target, dev): def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): input = te.placeholder(shape=input_shape, name="input", dtype=dtype) diagonal = te.placeholder(shape=diagonal_shape, 
name="diagonal", dtype=dtype) - matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align) + k1 = te.placeholder(shape=(1,), name="k1", dtype="int64") + k2 = te.placeholder(shape=(1,), name="k2", dtype="int64") + matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, (k1, k2), align) + + k_one, k_two = None, None + if isinstance(k, (tuple, list)): + k_one = k[0] + if len(k) >= 2: + k_two = k[1] + else: + k_two = k[0] + else: + k_one = k + k_two = k def check_device(target, dev): dev = tvm.device(target, 0) print("Running on target: %s" % target) with tvm.target.Target(target): s = tvm.topi.testing.get_injective_schedule(target)(matrix_set_diag_result) - fn = tvm.build(s, [input, diagonal, matrix_set_diag_result], target, name="matrix_set_diag") + fn = tvm.build( + s, [input, diagonal, k1, k2, matrix_set_diag_result], target, name="matrix_set_diag" + ) input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype) diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype) out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align) input_nd = tvm.nd.array(input_npy, dev) diagonal_nd = tvm.nd.array(diagonal_npy, dev) + k1_nd = tvm.nd.array(np.asarray([k_one]), dev) + k2_nd = tvm.nd.array(np.asarray([k_two]), dev) out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(matrix_set_diag_result.dtype), dev) - fn(input_nd, diagonal_nd, out_nd) + fn(input_nd, diagonal_nd, k1_nd, k2_nd, out_nd) out_topi = out_nd.numpy() tvm.testing.assert_allclose(out_topi, out_npy) @@ -861,10 +878,10 @@ def test_reinterpret(): (1000,), "int16", "uint16", lambda shape: np.random.randint(-1000, 1000, size=shape) ) verify_reinterpret( - (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape) + (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2**32 - 1, size=shape) ) verify_reinterpret( - (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape) + (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2**32 - 1, size=shape) ) @@ -1240,9 +1257,6 @@ def test_sparse_to_dense(): @tvm.testing.uses_gpu def test_matrix_set_diag(): for dtype in ["float32", "int32"]: - verify_matrix_set_diag((2, 2), (2,), dtype) - verify_matrix_set_diag((4, 3, 3), (4, 3), dtype) - verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1) verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_RIGHT") verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_LEFT") verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "RIGHT_RIGHT") From eff297b12ea39016a1f674fa64e0d9ae441093dc Mon Sep 17 00:00:00 2001 From: Brunno Goldstein Date: Fri, 1 Apr 2022 15:59:22 -0300 Subject: [PATCH 5/6] Fix lint over test_topi_transform.py. 
--- tests/python/topi/python/test_topi_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 64ca1488c188..80cfb02790e3 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -878,10 +878,10 @@ def test_reinterpret(): (1000,), "int16", "uint16", lambda shape: np.random.randint(-1000, 1000, size=shape) ) verify_reinterpret( - (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2**32 - 1, size=shape) + (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape) ) verify_reinterpret( - (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2**32 - 1, size=shape) + (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape) ) From 0f1535e8255b1a01ec649afc979e1ced580609a8 Mon Sep 17 00:00:00 2001 From: Brunno Goldstein Date: Tue, 5 Apr 2022 12:53:22 -0700 Subject: [PATCH 6/6] - Josh's comments: -- Fixed typo in transform.cc; -- Adding comments to verify_matrix_set_diag on test_topi_transform.py; -- Rollback TOPI tests set in test_matrix_set_diag; -- Rollback TFLite frontend convert_matrix_diag method; --- For some reason (maybe some package version), removing the last dimension was necessary (on Macbook M1); --- When moving to Linux, the old version was running smoothly and test passed; - Andrew's comments: -- Small change into matrix_set_diag of include/tvm/topi/transform.h; -- Fixed comment in relay/op/tensor/transform.cc --- include/tvm/topi/transform.h | 3 +- python/tvm/relay/frontend/tflite.py | 5 ++- src/relay/op/tensor/transform.cc | 4 +-- .../python/topi/python/test_topi_transform.py | 36 +++++++++++++++++-- 4 files changed, 38 insertions(+), 10 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 9c69b4500467..6be335670277 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1885,8 +1885,7 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, const : 0); return k; }; - auto get_k = [&]() { return if_then_else(k1(0) == k2(0), k1(0), multi_diagonals()); }; - k = get_k(); + k = if_then_else(k1(0) == k2(0), k1(0), multi_diagonals()); diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) + offset); return diagonal(diagonal_indices); diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index dfeb8bdb0966..a4f7a2925964 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3305,11 +3305,10 @@ def convert_matrix_diag(self, op): ), "TFLite MATRIX_DIAG requires diagonal and output tensors' \ scale and zero points to be equal" - # Tflite's output tensor for matrix_diag has rank k+1 shape = to_int_list(self.get_tensor_shape(diagonal)) - # Diagonal's tensor has rank k. Therefore we remove the last dimension [:-1]. 
- diag_shape = np.insert(shape[:-1], len(shape[:-1]) - 1, 1).astype(np.int32) + diag_shape = np.insert(shape, len(shape) - 1, 1).astype(np.int32) dtype = self.get_tensor_type_str(diagonal.tensor.Type()) + shape = np.append(shape, shape[-1]).astype(np.int32) input_expr = _op.zeros(tuple(shape), dtype) diagonal_expr = self.get_tensor_expr(diagonal) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 3df6f35d08d1..54f3bd3f5f18 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3811,7 +3811,7 @@ TVM_REGISTER_NODE_TYPE(MatrixSetDiagAttrs); bool MatrixSetDiagRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - // `types` contains: [input, diagonal, result] + // `types` contains: [input, diagonal, k1, k2, result] ICHECK_EQ(types.size(), 5); const auto* input = types[0].as(); @@ -3869,7 +3869,7 @@ RELAY_REGISTER_OP("matrix_set_diag") .set_num_inputs(4) .add_argument("input", "Tensor", "Input Tensor.") .add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.") - .add_argument("k1", "Tensor", "ILower limit (included) of the range of diagonals.") + .add_argument("k1", "Tensor", "Lower limit (included) of the range of diagonals.") .add_argument("k2", "Tensor", "Upper limit (included) of the range of diagonals.") .set_support_level(10) .add_type_rel("MatrixSetDiag", MatrixSetDiagRel) diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 80cfb02790e3..78bc50ace211 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -750,12 +750,20 @@ def check_device(target, dev): def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): + # input matrix that contains diagonals to be replaced input = te.placeholder(shape=input_shape, name="input", dtype=dtype) + # diagonal values to be placed as new diagonal values of input matrix diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype) + # diagonals offsets + # k1 and k2 define the lower and upper limits of diagonals to be set + # where k*=0 means main diagonal, k*< 0 sub-diagonal, and k*> 0 super-diagonal + # when k is not an tuple or list, k1 will be equal to k2, meaning that only one diagonal will be replaced. 
k1 = te.placeholder(shape=(1,), name="k1", dtype="int64") + # k2 defines the upper limit diagonal to be set k2 = te.placeholder(shape=(1,), name="k2", dtype="int64") matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, (k1, k2), align) + # k can be an integer or a pair of integers representing the lower and upper limits of a matrix band; k_one, k_two = None, None if isinstance(k, (tuple, list)): k_one = k[0] @@ -767,6 +775,14 @@ def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT k_one = k k_two = k + # Generate random data for input matrix + input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype) + # Generate random data for diagonal (single or multiple diagionals) + diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype) + # Run numpy test for matrix_set_diag with random data + # output will be saved to compare with TOPI version of matrix_set_diag + out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align) + def check_device(target, dev): dev = tvm.device(target, 0) print("Running on target: %s" % target) @@ -775,16 +791,27 @@ def check_device(target, dev): fn = tvm.build( s, [input, diagonal, k1, k2, matrix_set_diag_result], target, name="matrix_set_diag" ) - input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype) - diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype) - out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align) + + # Convert numpy input data to TVM ND array input_nd = tvm.nd.array(input_npy, dev) + + # Convert numpy diagonal data to TVM ND array diagonal_nd = tvm.nd.array(diagonal_npy, dev) + + # Convert k1 and k2 to numpy array and then to TVM ND array k1_nd = tvm.nd.array(np.asarray([k_one]), dev) k2_nd = tvm.nd.array(np.asarray([k_two]), dev) + + # Convert k1 and k2 to numpy array and then to TVM ND array out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(matrix_set_diag_result.dtype), dev) + + # Run TOPI test for matrix_set_diag with random data fn(input_nd, diagonal_nd, k1_nd, k2_nd, out_nd) + + # Convert TOPI output to numpy out_topi = out_nd.numpy() + + # Check if Numpy version matches TOPI one tvm.testing.assert_allclose(out_topi, out_npy) for target, dev in tvm.testing.enabled_targets(): @@ -1257,6 +1284,9 @@ def test_sparse_to_dense(): @tvm.testing.uses_gpu def test_matrix_set_diag(): for dtype in ["float32", "int32"]: + verify_matrix_set_diag((2, 2), (2,), dtype) + verify_matrix_set_diag((4, 3, 3), (4, 3), dtype) + verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1) verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_RIGHT") verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_LEFT") verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "RIGHT_RIGHT")
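As a sanity check on the band arithmetic the Trilu converter (patch 4/6) feeds into matrix_set_diag, a small NumPy sketch; `zero_band` is a throwaway helper standing in for matrix_set_diag called with an all-zeros diagonal input, not an API from this patch.

```python
import numpy as np

def zero_band(x, k1, k2):
    # Zero every diagonal d with k1 <= d <= k2 (d = col - row), which is what
    # matrix_set_diag does when given an all-zeros `diagonal` input.
    rows, cols = x.shape[-2:]
    r = np.arange(rows)[:, None]
    c = np.arange(cols)[None, :]
    band = (c - r >= k1) & (c - r <= k2)
    return np.where(band, 0, x)

x = np.arange(16).reshape(4, 4)
k = 1
# upper=0 (tril): zero diagonals k+1 .. num_cols, keeping the lower triangle.
assert np.array_equal(zero_band(x, k + 1, x.shape[-1]), np.tril(x, k))
# upper=1 (triu): zero diagonals -(num_rows)-1 .. k-1, keeping the upper triangle.
assert np.array_equal(zero_band(x, -x.shape[-2] - 1, k - 1), np.triu(x, k))
```

The k1/k2 choices mirror the converter: for upper=0 it clears diagonals k+1 through the column count, and for upper=1 it clears everything from below the lowest sub-diagonal up to k-1.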