From e69cf62f5da123ad61acca79b21019d793be4757 Mon Sep 17 00:00:00 2001 From: sxhu Date: Wed, 20 Oct 2021 15:37:20 +0800 Subject: [PATCH 1/6] [Relay/TOPI][ONNX/TFLite] Refactor MATRIX_SET_DIAG Operator for Relay/TOPI to support ONNX Trilu operator --- include/tvm/relay/attrs/transform.h | 4 -- include/tvm/topi/transform.h | 20 ++++----- python/tvm/relay/frontend/onnx.py | 30 +++++++++++++ python/tvm/relay/frontend/tflite.py | 14 +++++-- python/tvm/relay/op/transform.py | 6 +++ src/relay/op/tensor/transform.cc | 39 +++++++---------- src/topi/transform.cc | 4 +- tests/python/frontend/onnx/test_forward.py | 42 +++++++++++++++++++ tests/python/relay/test_op_level10.py | 8 +++- .../python/topi/python/test_topi_transform.py | 24 ++++++++--- 10 files changed, 139 insertions(+), 52 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 0e04b0936f24..b058caf03b70 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -456,14 +456,10 @@ struct OneHotAttrs : public tvm::AttrsNode { /*! 
\brief Attributes used in matrix_set_diag operator */ struct MatrixSetDiagAttrs : public tvm::AttrsNode { - int k1; - int k2; bool super_diag_right_align; bool sub_diag_right_align; TVM_DECLARE_ATTRS(MatrixSetDiagAttrs, "relay.attrs.MatrixSetDiagAttrs") { - TVM_ATTR_FIELD(k1).set_default(0).describe("Lower limit (included) of the range of diagonals."); - TVM_ATTR_FIELD(k2).set_default(0).describe("Upper limit (included) of the range of diagonals."); TVM_ATTR_FIELD(super_diag_right_align) .set_default(true) .describe("Bool, true iff super-diagonal is right aligned (left-padded)."); diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 3df9caf55d5c..3c44e96dcfee 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1759,14 +1759,13 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Arrayshape.size() - 1; - bool only_one_diagonal = k1 == k2; - return compute( input->shape, [&](const Array& iter_vars) { @@ -1776,12 +1775,10 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k for (size_t i = 0; i < ndim - 1; i++) { diagonal_indices.push_back(iter_vars[i]); } - if (only_one_diagonal) { - k = k1; - } else { + auto multi_diagonals = [&]() { // Determining which diagonal/sub-diagonal/super-diagonal it is k = iter_vars[ndim] - iter_vars[ndim - 1]; - diagonal_indices.push_back(k2 - k); + diagonal_indices.push_back(k2(0) - k); // Calculating the offset in diagonal tensor for this diagonal auto get_offset = [&](PrimExpr M, PrimExpr N) { @@ -1794,13 +1791,16 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k : 0, sub_diag_right_align ? 
get_offset(input->shape[ndim], input->shape[ndim - 1] + k) : 0); - } + return k; + }; + auto get_k = [&]() { return if_then_else(k1(0) == k2(0), k1(0), multi_diagonals()); }; + k = get_k(); diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) + offset); return diagonal(diagonal_indices); }; - return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1, - if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2, + return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1(0), + if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2(0), get_diag(), input(iter_vars)), input(iter_vars)); }, diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5c112c7dfce0..ba3816f4ab03 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4226,6 +4226,35 @@ def _impl_v1(cls, inputs, attr, params): return _expr.TupleWrapper(_expr.Tuple(result), len(result)) +class Trilu(OnnxOpConverter): + """Operator converter for Trilu""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + upper = attr.get("upper") + input_shape = shape_of(inputs[0]) + input_dims = infer_shape(input_shape)[0] + data_type = infer_type(inputs[0]).checked_type.dtype + k_tensor = relay.const(np.asarray([0], dtype=np.int64)) + if len(inputs) == 2: + k_tensor = inputs[1] + + diag_input = relay.zeros(fold_constant(shape_of(inputs[0])), dtype=data_type) + + if upper == 0: + k1 = relay.add(k_tensor, relay.const(1, dtype="int64")) + k2 = relay.take(input_shape, relay.const(input_dims - 1, dtype="int32")) + k2 = relay.expand_dims(k2, axis=0) + return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2)) + else: + k1 = relay.take(input_shape, relay.const(input_dims-2, dtype="int32")) + k1 = relay.multiply(k1, relay.const(-1, dtype="int64")) + k1 = relay.subtract(k1, relay.const(1, dtype="int64")) + k1 = relay.expand_dims(k1, axis=0) + k2 = 
relay.subtract(k_tensor, relay.const(1, dtype="int64")) + return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2)) + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -4425,6 +4454,7 @@ def _get_convert_map(opset): "Adagrad": Adagrad.get_converter(opset), "Adam": Adam.get_converter(opset), "Momentum": Momentum.get_converter(opset), + "Trilu": Trilu.get_converter(opset), } diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 3688ff5ff4e5..184b69a05a95 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3293,6 +3293,11 @@ def convert_matrix_set_diag(self, op): input_expr = self.get_tensor_expr(input_tensors[0]) diagonal_expr = self.get_tensor_expr(input_tensors[1]) + diag_shape = to_int_list(self.get_tensor_shape(input_tensors[1])) + input_shape = to_int_list(self.get_tensor_shape(input_tensors[0])) + if len(diag_shape) == len(input_shape) - 1: + diag_shape = np.insert(diag_shape, len(diag_shape)-1, 1) + diagonal_expr = _op.reshape(diagonal_expr, diag_shape) out = _op.matrix_set_diag(input_expr, diagonal_expr) return out @@ -3313,13 +3318,16 @@ def convert_matrix_diag(self, op): scale and zero points to be equal" shape = to_int_list(self.get_tensor_shape(diagonal)) - shape = np.append(shape, shape[-1]) + shape_copy = np.copy(shape) + diag_shape = np.insert(shape, len(shape)-1, 1).astype(np.int32) + + shape = np.append(shape_copy, shape[-1]).astype(np.int32) dtype = self.get_tensor_type_str(diagonal.tensor.Type()) - + input_expr = _op.zeros(tuple(shape), dtype) diagonal_expr = self.get_tensor_expr(diagonal) - out = _op.matrix_set_diag(input_expr, diagonal_expr) + out = _op.matrix_set_diag(input_expr, _op.reshape(diagonal_expr, diag_shape)) return out def convert_densify(self, op): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 234e76b11813..4419dd360581 100644 --- a/python/tvm/relay/op/transform.py +++ 
b/python/tvm/relay/op/transform.py @@ -23,6 +23,7 @@ from . import _make from .dyn import _make as _dyn_make from .tensor import shape_of +import numpy as np def cast(data, dtype): @@ -1352,6 +1353,11 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"): k_one = k k_two = k + if not isinstance(k_one, Expr): + k_one = const(np.asarray([k_one], dtype=np.int64)) + if not isinstance(k_two, Expr): + k_two = const(np.asarray([k_two], dtype=np.int64)) + super_diag_right_align = align[:5] == "RIGHT" sub_diag_right_align = align[-5:] == "RIGHT" diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 90a0e3150573..836008d36026 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3708,7 +3708,7 @@ TVM_REGISTER_NODE_TYPE(MatrixSetDiagAttrs); bool MatrixSetDiagRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [input, diagonal, result] - ICHECK_EQ(types.size(), 3); + ICHECK_EQ(types.size(), 5); const auto* input = types[0].as(); ICHECK(input); @@ -3716,30 +3716,19 @@ bool MatrixSetDiagRel(const Array& types, int num_inputs, const Attrs& att const auto* diagonal = types[1].as(); ICHECK(diagonal); - const auto param = attrs.as(); - ICHECK_GE(param->k2, param->k1); - - int d_ndims = diagonal->shape.size(); - int i_ndims = input->shape.size(); + const auto* k1 = types[2].as(); + ICHECK(k1); - reporter->Assert(input->shape[i_ndims - 2] > -param->k1); - reporter->Assert(input->shape[i_ndims - 1] > param->k2); + const auto* k2 = types[3].as(); + ICHECK(k2); + int d_ndims = diagonal->shape.size(); + for (int i = 0; i < d_ndims - 2; i++) { reporter->AssertEQ(input->shape[i], diagonal->shape[i]); } - if (param->k1 != param->k2) { - reporter->AssertEQ(diagonal->shape[d_ndims - 2], param->k2 - param->k1 + 1); - } else if (d_ndims >= 2) { - reporter->AssertEQ(input->shape[d_ndims - 2], diagonal->shape[d_ndims - 2]); - } - auto max_diag_len = 
if_then_else(input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0) <= - input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0), - input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0), - input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0)); - reporter->AssertEQ(diagonal->shape[d_ndims - 1], max_diag_len); - reporter->Assign(types[2], TensorType(input->shape, input->dtype)); + reporter->Assign(types[4], TensorType(input->shape, input->dtype)); return true; } @@ -3747,20 +3736,18 @@ Array MatrixSetDiagCompute(const Attrs& attrs, const Array(); ICHECK(param != nullptr); - return Array{topi::matrix_set_diag(inputs[0], inputs[1], param->k1, param->k2, + return Array{topi::matrix_set_diag(inputs[0], inputs[1], inputs[2], inputs[3], param->super_diag_right_align, param->sub_diag_right_align)}; } -Expr MakeMatrixSetDiag(Expr input, Expr diagonal, int k1, int k2, bool super_diag_right_align, +Expr MakeMatrixSetDiag(Expr input, Expr diagonal, Expr k1, Expr k2, bool super_diag_right_align, bool sub_diag_right_align) { auto attrs = make_object(); - attrs->k1 = k1; - attrs->k2 = k2; attrs->super_diag_right_align = super_diag_right_align; attrs->sub_diag_right_align = sub_diag_right_align; static const Op& op = Op::Get("matrix_set_diag"); - return Call(op, {input, diagonal}, Attrs(attrs), {}); + return Call(op, {input, diagonal, k1, k2}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag); @@ -3776,9 +3763,11 @@ RELAY_REGISTER_OP("matrix_set_diag") **sub_diag_right_align** Bool, true iff sub-diagonal is right aligned (left-padded). 
)code" TVM_ADD_FILELINE) .set_attrs_type() - .set_num_inputs(2) + .set_num_inputs(4) .add_argument("input", "Tensor", "Input Tensor.") .add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.") + .add_argument("k1", "Tensor", "Lower limit (included) of the range of diagonals.") + .add_argument("k2", "Tensor", "Upper limit (included) of the range of diagonals.") .set_support_level(10) .add_type_rel("MatrixSetDiag", MatrixSetDiagRel) .set_attr("FTVMCompute", MatrixSetDiagCompute) diff --git a/src/topi/transform.cc b/src/topi/transform.cc index db54d5a99a91..c754f778462f 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -209,11 +209,9 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValue* rv) { - int k1 = args[2]; - int k2 = args[3]; bool super_diag_right_align = args[4]; bool sub_diag_right_align = args[5]; - *rv = matrix_set_diag(args[0], args[1], k1, k2, super_diag_right_align, sub_diag_right_align); + *rv = matrix_set_diag(args[0], args[1], args[2], args[3], super_diag_right_align, sub_diag_right_align); }); TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index dd1c77330986..d69e3e9220b5 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5773,6 +5773,47 @@ def repeat(N, D): repeat(2, D), ) +@tvm.testing.parametrize_targets +def test_trilu(target, dev): + def verify_trilu(in_shape, k, upper): + trilu_node = helper.make_node('Trilu', inputs=["x", "k"], outputs=["out"], upper=upper) + graph = helper.make_graph( + [trilu_node], + "trilu_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("k", TensorProto.INT64, list((1,))), + ], + 
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(in_shape))], + ) + + model = helper.make_model(graph, producer_name="trilu_test") + input_array = np.random.rand(*in_shape).astype("float32") + verify_with_ort_with_inputs(model, [input_array, np.asarray(k)], target=target, dev=dev, use_vm=True) + + in_shape = (4, 5) + verify_trilu(in_shape, [4], 0) + verify_trilu(in_shape, [5], 0) + verify_trilu(in_shape, [5], 1) + verify_trilu(in_shape, [-1], 0) + verify_trilu(in_shape, [-1], 1) + verify_trilu(in_shape, [-7], 0) + verify_trilu(in_shape, [-7], 1) + verify_trilu(in_shape, [6], 0) + verify_trilu(in_shape, [6], 1) + + in_shape = (3, 1, 5) + verify_trilu(in_shape, [0], 0) + verify_trilu(in_shape, [1], 1) + verify_trilu(in_shape, [6], 0) + verify_trilu(in_shape, [6], 1) + + in_shape = (3, 5, 5) + verify_trilu(in_shape, [0], 0) + verify_trilu(in_shape, [0], 1) + verify_trilu(in_shape, [-1], 0) + verify_trilu(in_shape, [-1], 1) + if __name__ == "__main__": test_flatten() @@ -5864,3 +5905,4 @@ def repeat(N, D): test_convinteger() test_batch_matmul() test_global_lppool() + test_trilu() diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 22b3983df1c3..dab164bc9585 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -543,7 +543,13 @@ def test_matrix_set_diag(): def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): input = relay.var("input", relay.TensorType(input_shape, dtype)) diagonal = relay.var("diagonal", relay.TensorType(diagonal_shape, dtype)) - out = relay.matrix_set_diag(input, diagonal, k, align) + out = None + if len(diagonal_shape) == len(input_shape) - 1: + new_shape = list(diagonal_shape) + new_shape.insert(-1, 1) + out = relay.matrix_set_diag(input, relay.reshape(diagonal, new_shape), k, align) + else: + out = relay.matrix_set_diag(input, diagonal, k, align) in_type = run_infer_type(input) out_type = run_infer_type(out) diff --git 
a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 42d2463b8952..19cb0dbc3cbd 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -753,21 +753,36 @@ def check_device(target, dev): def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): input = te.placeholder(shape=input_shape, name="input", dtype=dtype) diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype) - matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align) + k1 = te.placeholder(shape=(1,), name="k1", dtype="int64") + k2 = te.placeholder(shape=(1,), name="k2", dtype="int64") + matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, (k1, k2), align) + + k_one, k_two = None, None + if isinstance(k, (tuple, list)): + k_one = k[0] + if len(k) >= 2: + k_two = k[1] + else: + k_two = k[0] + else: + k_one = k + k_two = k def check_device(target, dev): dev = tvm.device(target, 0) print("Running on target: %s" % target) with tvm.target.Target(target): s = tvm.topi.testing.get_injective_schedule(target)(matrix_set_diag_result) - fn = tvm.build(s, [input, diagonal, matrix_set_diag_result], target, name="matrix_set_diag") + fn = tvm.build(s, [input, diagonal, k1, k2, matrix_set_diag_result], target, name="matrix_set_diag") input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype) diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype) out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align) input_nd = tvm.nd.array(input_npy, dev) diagonal_nd = tvm.nd.array(diagonal_npy, dev) + k1_nd = tvm.nd.array(np.asarray([k_one]), dev) + k2_nd = tvm.nd.array(np.asarray([k_two]), dev) out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(matrix_set_diag_result.dtype), dev) - fn(input_nd, diagonal_nd, out_nd) + fn(input_nd, diagonal_nd, k1_nd, k2_nd, out_nd) 
out_topi = out_nd.numpy() tvm.testing.assert_allclose(out_topi, out_npy) @@ -1235,9 +1250,6 @@ def test_sparse_to_dense(): @tvm.testing.uses_gpu def test_matrix_set_diag(): for dtype in ["float32", "int32"]: - verify_matrix_set_diag((2, 2), (2,), dtype) - verify_matrix_set_diag((4, 3, 3), (4, 3), dtype) - verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1) verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_RIGHT") verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_LEFT") verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "RIGHT_RIGHT") From 59ae654d7cd25262f878d8b6b7b8ff2e82e9ffc5 Mon Sep 17 00:00:00 2001 From: sxhu Date: Fri, 22 Oct 2021 18:25:18 +0800 Subject: [PATCH 2/6] fix lint and optimize --- include/tvm/topi/transform.h | 6 +++--- python/tvm/relay/frontend/onnx.py | 4 ++-- src/relay/op/tensor/transform.cc | 3 +-- src/topi/transform.cc | 3 ++- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 3c44e96dcfee..0b87b8e4e2f5 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1759,9 +1759,9 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Arrayshape.size() - 1; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ba3816f4ab03..1573ace19d33 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4230,8 +4230,8 @@ class Trilu(OnnxOpConverter): """Operator converter for Trilu""" @classmethod - def _impl_v1(cls, inputs, attr, params): - upper = attr.get("upper") + def _impl_v14(cls, inputs, attr, params): + upper = attr.get("upper", 1) input_shape = shape_of(inputs[0]) input_dims = infer_shape(input_shape)[0] data_type = infer_type(inputs[0]).checked_type.dtype diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 836008d36026..afb9456a10ed 100644 --- a/src/relay/op/tensor/transform.cc +++ 
b/src/relay/op/tensor/transform.cc @@ -3722,8 +3722,7 @@ bool MatrixSetDiagRel(const Array& types, int num_inputs, const Attrs& att const auto* k2 = types[3].as(); ICHECK(k2); - int d_ndims = diagonal->shape.size(); - + int d_ndims = diagonal->shape.size(); for (int i = 0; i < d_ndims - 2; i++) { reporter->AssertEQ(input->shape[i], diagonal->shape[i]); } diff --git a/src/topi/transform.cc b/src/topi/transform.cc index c754f778462f..7f704f0f5ac3 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -211,7 +211,8 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValue* rv) { bool super_diag_right_align = args[4]; bool sub_diag_right_align = args[5]; - *rv = matrix_set_diag(args[0], args[1], args[2], args[3], super_diag_right_align, sub_diag_right_align); + *rv = matrix_set_diag(args[0], args[1], args[2], args[3], super_diag_right_align, + sub_diag_right_align); }); TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) { From 139f815febfc30c780d3870a4c995ffe6e7ed95e Mon Sep 17 00:00:00 2001 From: sxhu Date: Mon, 25 Oct 2021 15:46:14 +0800 Subject: [PATCH 3/6] fix lint --- src/relay/op/tensor/transform.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index afb9456a10ed..03e70c8bf75a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3722,7 +3722,7 @@ bool MatrixSetDiagRel(const Array& types, int num_inputs, const Attrs& att const auto* k2 = types[3].as(); ICHECK(k2); - int d_ndims = diagonal->shape.size(); + int d_ndims = diagonal->shape.size(); for (int i = 0; i < d_ndims - 2; i++) { reporter->AssertEQ(input->shape[i], diagonal->shape[i]); } From ae010893690f7be77f5383f3997d81e6ae02b6df Mon Sep 17 00:00:00 2001 From: sxhu Date: Tue, 26 Oct 2021 09:41:26 +0800 Subject: [PATCH 4/6] fix lint --- 
python/tvm/relay/frontend/onnx.py | 8 ++++---- python/tvm/relay/frontend/tflite.py | 6 ++---- python/tvm/relay/op/transform.py | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1573ace19d33..c50298802ee0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4240,19 +4240,19 @@ def _impl_v14(cls, inputs, attr, params): k_tensor = inputs[1] diag_input = relay.zeros(fold_constant(shape_of(inputs[0])), dtype=data_type) - + k1, k2 = None, None if upper == 0: k1 = relay.add(k_tensor, relay.const(1, dtype="int64")) k2 = relay.take(input_shape, relay.const(input_dims - 1, dtype="int32")) k2 = relay.expand_dims(k2, axis=0) - return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2)) else: - k1 = relay.take(input_shape, relay.const(input_dims-2, dtype="int32")) + k1 = relay.take(input_shape, relay.const(input_dims - 2, dtype="int32")) k1 = relay.multiply(k1, relay.const(-1, dtype="int64")) k1 = relay.subtract(k1, relay.const(1, dtype="int64")) k1 = relay.expand_dims(k1, axis=0) k2 = relay.subtract(k_tensor, relay.const(1, dtype="int64")) - return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2)) + + return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2)) # compatible operators that do NOT require any conversion. 
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 184b69a05a95..d3576905cd74 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3319,11 +3319,9 @@ def convert_matrix_diag(self, op): shape = to_int_list(self.get_tensor_shape(diagonal)) shape_copy = np.copy(shape) - diag_shape = np.insert(shape, len(shape)-1, 1).astype(np.int32) - + diag_shape = np.insert(shape, len(shape)-1, 1).astype(np.int32) shape = np.append(shape_copy, shape[-1]).astype(np.int32) - dtype = self.get_tensor_type_str(diagonal.tensor.Type()) - + dtype = self.get_tensor_type_str(diagonal.tensor.Type()) input_expr = _op.zeros(tuple(shape), dtype) diagonal_expr = self.get_tensor_expr(diagonal) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 4419dd360581..2a6844c82f04 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -18,12 +18,12 @@ # pylint: disable=import-outside-toplevel """Transform operators.""" +import numpy as np from ...tir import expr as _expr from ..expr import Constant, Expr, Tuple, TupleWrapper, const from . 
import _make from .dyn import _make as _dyn_make from .tensor import shape_of -import numpy as np def cast(data, dtype): From 26bb2400a753a3ebe9f2c40d46b7a10ae3f894af Mon Sep 17 00:00:00 2001 From: sxhu Date: Tue, 26 Oct 2021 12:23:06 +0800 Subject: [PATCH 5/6] fix lint --- python/tvm/relay/frontend/onnx.py | 2 +- python/tvm/relay/frontend/tflite.py | 6 +++--- tests/python/frontend/onnx/test_forward.py | 7 +++++-- tests/python/topi/python/test_topi_transform.py | 4 +++- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c50298802ee0..f5e1c5d7d34e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4228,7 +4228,7 @@ def _impl_v1(cls, inputs, attr, params): class Trilu(OnnxOpConverter): """Operator converter for Trilu""" - + @classmethod def _impl_v14(cls, inputs, attr, params): upper = attr.get("upper", 1) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d3576905cd74..8011c21d47ad 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3296,7 +3296,7 @@ def convert_matrix_set_diag(self, op): diag_shape = to_int_list(self.get_tensor_shape(input_tensors[1])) input_shape = to_int_list(self.get_tensor_shape(input_tensors[0])) if len(diag_shape) == len(input_shape) - 1: - diag_shape = np.insert(diag_shape, len(diag_shape)-1, 1) + diag_shape = np.insert(diag_shape, len(diag_shape) - 1, 1) diagonal_expr = _op.reshape(diagonal_expr, diag_shape) out = _op.matrix_set_diag(input_expr, diagonal_expr) @@ -3319,9 +3319,9 @@ def convert_matrix_diag(self, op): shape = to_int_list(self.get_tensor_shape(diagonal)) shape_copy = np.copy(shape) - diag_shape = np.insert(shape, len(shape)-1, 1).astype(np.int32) + diag_shape = np.insert(shape, len(shape) - 1, 1).astype(np.int32) shape = np.append(shape_copy, shape[-1]).astype(np.int32) - dtype = 
self.get_tensor_type_str(diagonal.tensor.Type()) + dtype = self.get_tensor_type_str(diagonal.tensor.Type()) input_expr = _op.zeros(tuple(shape), dtype) diagonal_expr = self.get_tensor_expr(diagonal) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d69e3e9220b5..72117f532bb6 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5773,10 +5773,11 @@ def repeat(N, D): repeat(2, D), ) + @tvm.testing.parametrize_targets def test_trilu(target, dev): def verify_trilu(in_shape, k, upper): - trilu_node = helper.make_node('Trilu', inputs=["x", "k"], outputs=["out"], upper=upper) + trilu_node = helper.make_node("Trilu", inputs=["x", "k"], outputs=["out"], upper=upper) graph = helper.make_graph( [trilu_node], "trilu_test", @@ -5789,7 +5790,9 @@ def verify_trilu(in_shape, k, upper): model = helper.make_model(graph, producer_name="trilu_test") input_array = np.random.rand(*in_shape).astype("float32") - verify_with_ort_with_inputs(model, [input_array, np.asarray(k)], target=target, dev=dev, use_vm=True) + verify_with_ort_with_inputs( + model, [input_array, np.asarray(k)], target=target, dev=dev, use_vm=True + ) in_shape = (4, 5) verify_trilu(in_shape, [4], 0) diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 19cb0dbc3cbd..97ccf9457521 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -773,7 +773,9 @@ def check_device(target, dev): print("Running on target: %s" % target) with tvm.target.Target(target): s = tvm.topi.testing.get_injective_schedule(target)(matrix_set_diag_result) - fn = tvm.build(s, [input, diagonal, k1, k2, matrix_set_diag_result], target, name="matrix_set_diag") + fn = tvm.build( + s, [input, diagonal, k1, k2, matrix_set_diag_result], target, name="matrix_set_diag" + ) input_npy = np.random.randint(-100, 100, 
size=input_shape).astype(dtype) diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype) out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align) From ab40d8053cad8027eb54d4748f1def2ec4e6f58d Mon Sep 17 00:00:00 2001 From: sxhu Date: Wed, 5 Jan 2022 10:24:58 +0800 Subject: [PATCH 6/6] fix lint --- tests/python/frontend/onnx/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7262a5838d43..8ae4bf2f0083 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -6014,7 +6014,7 @@ def verify_trilu(in_shape, k, upper): verify_trilu(in_shape, [-1], 0) verify_trilu(in_shape, [-1], 1) - + def test_scan(target, dev): def verify_scan( input_shapes,