From 89f0e52b6744928c0d95e58f1bddb3d9e44bc057 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Mon, 18 Jul 2022 14:55:13 -0700
Subject: [PATCH 1/7] Added topi trilu implementation

---
 python/tvm/topi/transform.py                  | 53 +++++++++++++++++++
 .../python/topi/python/test_topi_transform.py | 42 ++++++++++++++-
 2 files changed, 93 insertions(+), 2 deletions(-)

diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index d99d6772b0cd..40a98d5dd557 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -1001,3 +1001,56 @@ def sliding_window(data, axis, window_shape, strides):
         The resulting tensor.
     """
     return cpp.sliding_window(data, axis, window_shape, strides)
+
+
+def trilu(x, upper, k):
+    """
+    Given a 2-D matrix or batches of 2-D matrices, returns the
+    upper or lower triangular part of the tensor.
+
+    Parameters
+    ----------
+    x: tvm.te.Tensor
+        The tensor that trilu will be applied to. Must be either
+        a 2D matrix or a tensor of batches of 2D matrices.
+
+    upper: bool
+        If True, only upper triangular values of input are kept,
+        if False, the lower triangular values are kept.
+
+    k: int
+        The number of diagonals above or below the main diagonal
+        to exclude or include.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The new tensor with appropriate diagonals set to zero.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[0, 1, 2],
+             [3, 4, 5],
+             [6, 7, 8]]
+
+        topi.trilu(x, True, 0) =
+            [[0, 1, 2],
+             [0, 4, 5],
+             [0, 0, 8]]
+    """
+    # Check either above or below diagonal depending on upper.
+    check_op = tvm.tir.GE
+    if upper:
+        check_op = tvm.tir.LE
+
+    def _apply_trilu(*indices):
+        row_index = indices[-2]
+        col_index = indices[-1]
+        other_indices = indices[:-2]
+        check_position = check_op(row_index, col_index - k)
+        value = x(*other_indices, row_index, col_index)
+        return tvm.tir.Select(check_position, value, tvm.tir.const(0, x.dtype))
+
+    return te.compute(x.shape, _apply_trilu, name="trilu")
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index 180f267650cc..c4493477abb1 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -812,6 +812,30 @@ def check_device(target, dev):
         check_device(target, dev)
 
 
+def verify_trilu(input_shape, upper, k=0):
+    x = te.placeholder(shape=input_shape, name="x", dtype="float32")
+    trilu_result = topi.transform.trilu(x, upper, k)
+
+    def check_device(target, dev):
+        print("Running on target: %s" % target)
+        with tvm.target.Target(target):
+            s = tvm.topi.testing.get_injective_schedule(target)(trilu_result)
+        fn = tvm.build(s, [x, trilu_result], target, name="trilu")
+        x_npy = np.random.normal(size=input_shape).astype(x.dtype)
+        if upper:
+            out_npy = np.triu(x_npy, k)
+        else:
+            out_npy = np.tril(x_npy, k)
+        x_nd = tvm.nd.array(x_npy, dev)
+        out_nd = tvm.nd.array(np.empty(x_npy.shape).astype(trilu_result.dtype), dev)
+        fn(x_nd, out_nd)
+        out_topi = out_nd.numpy()
+        tvm.testing.assert_allclose(out_topi, out_npy)
+
+    for target, dev in tvm.testing.enabled_targets():
+        check_device(target, dev)
+
+
 @tvm.testing.uses_gpu
 def test_strided_slice():
     verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
@@ -861,10 +885,10 @@ def test_reinterpret():
         (1000,), "int16", "uint16", lambda shape: np.random.randint(-1000, 1000, size=shape)
     )
     verify_reinterpret(
-        (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2**32 - 1, size=shape)
+        (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape)
     )
     verify_reinterpret(
-        (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2**32 - 1, size=shape)
+        (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape)
     )
 
 
@@ -1256,6 +1280,19 @@ def test_adv_index():
         verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)], indice_dtype=indice_dtype)
 
 
+@tvm.testing.uses_gpu
+def test_trilu():
+    # Test upper and lower triangle
+    verify_trilu((3, 3), True, 0)
+    verify_trilu((3, 3), False, 0)
+    # Test larger matrices with offset.
+    verify_trilu((6, 6), True, 1)
+    verify_trilu((6, 6), False, 2)
+    verify_trilu((6, 6), False, -2)
+    # Test batch size
+    verify_trilu((8, 6, 6), False, -2)
+
+
 if __name__ == "__main__":
     test_strided_slice()
     test_concatenate()
@@ -1283,3 +1320,4 @@ def test_adv_index():
     test_sparse_to_dense()
     test_matrix_set_diag()
     test_adv_index()
+    test_trilu()

From 3f78abf5b7ece38070500354aad0d2a2ca4f6dd3 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Mon, 18 Jul 2022 16:18:50 -0700
Subject: [PATCH 2/7] Implemented and tested full Trilu op.

---
 include/tvm/relay/attrs/transform.h           |  9 ++++
 python/tvm/relay/frontend/onnx.py             | 15 ++++++
 python/tvm/relay/op/_transform.py             |  4 ++
 python/tvm/relay/op/op_attrs.py               |  5 ++
 python/tvm/relay/op/strategy/generic.py       | 28 +++++++++++
 python/tvm/relay/op/transform.py              | 43 ++++++++++++++++
 python/tvm/topi/transform.py                  | 21 +++++---
 src/relay/op/tensor/transform.cc              | 50 +++++++++++++++++++
 tests/python/frontend/onnx/test_forward.py    | 16 ------
 tests/python/relay/test_op_level3.py          | 29 +++++++++++
 .../python/topi/python/test_topi_transform.py |  6 +--
 11 files changed, 199 insertions(+), 27 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index b9f8c6e1e847..2741d68eec14 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -575,6 +575,15 @@ struct StftAttrs : public tvm::AttrsNode<StftAttrs> {
   }
 };  // struct StftAttrs
 
+struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
+  bool upper;
+
+  TVM_DECLARE_ATTRS(TriluAttrs, "relay.attrs.TriluAttrs") {
+    TVM_ATTR_FIELD(upper).set_default(true).describe(
+        "Whether to keep the upper or lower half of the diagonal.");
+  }
+};  // struct TriluAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 3b5bf9acfa42..e78e65dc4ec6 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -4685,6 +4685,20 @@ def _impl_v12(cls, inputs, attr, params):
         return _op.einsum(inputs, equation)
 
 
+class Trilu(OnnxOpConverter):
+    """Operator converter for Trilu"""
+
+    @classmethod
+    def _impl_v14(cls, inputs, attr, params):
+        upper = attr.get("upper", True)
+        if len(inputs) == 2:
+            data, k = inputs
+        else:
+            data = inputs[0]
+            k = 0
+        return _op.trilu(data, k, upper)
+
+
 class RandomNormal(OnnxOpConverter):
     """Operator converter for random_normal"""
 
@@ -5345,6 +5359,7 @@ def _get_convert_map(opset):
         "CumSum": CumSum.get_converter(opset),
         "Unique": Unique.get_converter(opset),
         "Einsum": Einsum.get_converter(opset),
+        "Trilu": Trilu.get_converter(opset),
         # defs/control_flow
         "Loop": Loop.get_converter(opset),
         "If": If.get_converter(opset),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index baf616a94662..951de06967fb 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -191,6 +191,10 @@ def stft_shape_func(attrs, inputs, _):
     ]
 
 
+# trilu
+_reg.register_strategy("trilu", strategy.trilu_strategy)
+
+
 # scatter_add
 @_reg.register_compute("scatter_add")
 def compute_scatter_add(attrs, inputs, output_type):
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 8b92fdf2672d..7e8367abbb2f 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -617,3 +617,8 @@ class NLLLossAttrs(Attrs):
 @tvm._ffi.register_object("relay.attrs.FixedPointMultiplyAttrs")
 class FixedPointMultiplyAttrs(Attrs):
     """Attributes used in fixed_point_multiply operators"""
+
+
+@tvm._ffi.register_object("relay.attrs.TriluAttrs")
+class TriluAttrs(Attrs):
+    """Attributes used in trilu operators"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 6074b0a69cc3..95558b5f3d9a 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1460,6 +1460,34 @@ def _compute_stft(attrs, inputs, output_type):
     return _compute_stft
 
 
+# trilu
+@override_native_generic_func("trilu_strategy")
+def trilu_strategy(attrs, outs, out_type, target):
+    """trilu generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_trilu(topi.trilu),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="trilu.generic",
+    )
+    return strategy
+
+
+def wrap_compute_trilu(topi_compute):
+    """Wrap trilu compute"""
+
+    def _compute_trilu(attrs, inputs, output_type):
+        return [
+            topi_compute(
+                inputs[0],
+                inputs[1],
+                attrs.upper,
+            )
+        ]
+
+    return _compute_trilu
+
+
 # roi_pool
 @generic_func
 def schedule_roi_pool(attrs, outs, target):
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index b5d44781e5e3..e7ae5f7d8315 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1889,3 +1889,46 @@ def stft(
         window = _make.ones([n_fft], "int32")
 
     return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)
+
+
+def trilu(data, k, upper=True):
+    """
+    Given a 2-D matrix or batches of 2-D matrices, returns the
+    upper or lower triangular part of the tensor.
+
+    Parameters
+    ----------
+    data: relay.Expr
+        The tensor that trilu will be applied to. Must be either
+        a 2D matrix or a tensor of batches of 2D matrices.
+
+    k: int
+        The number of diagonals above or below the main diagonal
+        to exclude or include.
+
+    upper: bool, optional
+        If True, only upper triangular values of input are kept,
+        if False, the lower triangular values are kept.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The new tensor with appropriate diagonals set to zero.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[0, 1, 2],
+             [3, 4, 5],
+             [6, 7, 8]]
+
+        relay.trilu(x, 0, True) =
+            [[0, 1, 2],
+             [0, 4, 5],
+             [0, 0, 8]]
+    """
+    if not isinstance(k, Expr):
+        k = const(k, dtype="int32")
+    return _make.trilu(data, k, upper)
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 40a98d5dd557..e12f80e2ef2e 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -1003,24 +1003,25 @@ def sliding_window(data, axis, window_shape, strides):
     return cpp.sliding_window(data, axis, window_shape, strides)
 
 
-def trilu(x, upper, k):
+def trilu(data, k, upper):
     """
     Given a 2-D matrix or batches of 2-D matrices, returns the
     upper or lower triangular part of the tensor.
 
     Parameters
     ----------
-    x: tvm.te.Tensor
+    data: tvm.te.Tensor
         The tensor that trilu will be applied to. Must be either
         a 2D matrix or a tensor of batches of 2D matrices.
 
+    k: tvm.te.Tensor
+        The number of diagonals above or below the main diagonal
+        to exclude or include.
+
     upper: bool
         If True, only upper triangular values of input are kept,
         if False, the lower triangular values are kept.
 
-    k: int
-        The number of diagonals above or below the main diagonal
-        to exclude or include.
-
     Returns
     -------
@@ -1040,6 +1041,10 @@ def trilu(x, upper, k):
              [0, 4, 5],
             [0, 0, 8]]
     """
+    # Make sure datatype is consistent.
+    if k.dtype != "int32":
+        k = tvm.tir.Cast("int32", k)
+
     # Check either above or below diagonal depending on upper.
     check_op = tvm.tir.GE
     if upper:
@@ -1050,7 +1055,7 @@ def _apply_trilu(*indices):
         col_index = indices[-1]
         other_indices = indices[:-2]
         check_position = check_op(row_index, col_index - k)
-        value = x(*other_indices, row_index, col_index)
-        return tvm.tir.Select(check_position, value, tvm.tir.const(0, x.dtype))
+        value = data(*other_indices, row_index, col_index)
+        return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))
 
-    return te.compute(x.shape, _apply_trilu, name="trilu")
+    return te.compute(data.shape, _apply_trilu, name="trilu")
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 989ab2ad25d3..f90cd91e927b 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -4230,5 +4230,55 @@ RELAY_REGISTER_OP("invert_permutation")
     .set_attr<TOpPattern>("TOpPattern", kInjective)
     .set_attr<TOpIsStateful>("TOpIsStateful", false);
 
+// Trilu
+
+TVM_REGISTER_NODE_TYPE(TriluAttrs);
+
+bool TriluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+              const TypeReporter& reporter) {
+  // types: [data, k, result]
+  ICHECK_EQ(types.size(), 3) << "Trilu: expect 3 types but " << types.size() << " provided";
+  ICHECK_EQ(num_inputs, 2) << "Trilu: expect 2 inputs but " << num_inputs << " provided";
+  auto data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    ICHECK(types[0].as<IncompleteTypeNode>())
+        << "Trilu: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+
+  auto k = types[1].as<TensorTypeNode>();
+  if (k == nullptr) {
+    ICHECK(types[1].as<IncompleteTypeNode>())
+        << "Trilu: expect k type to be TensorType but get " << types[1];
+    return false;
+  }
+
+  ICHECK(k->shape.size() == 0) << "Trilu: k must be a 0-D tensor but get " << k;
+
+  // Output shape is the same as input shape.
+  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
+  return true;
+}
+
+Expr MakeTrilu(Expr data, Expr k, bool upper) {
+  auto attrs = make_object<TriluAttrs>();
+  attrs->upper = upper;
+  static const Op& op = Op::Get("trilu");
+  return Call(op, {data, k}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.trilu").set_body_typed(MakeTrilu);
+
+RELAY_REGISTER_OP("trilu")
+    .describe(
+        R"code(Filters out the upper or lower portion of an input tensor on one side of a diagonal.
+    )code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor")
+    .add_argument("k", "Tensor", "The number of diagonals above or below the main to exclude.")
+    .add_type_rel("trilu", TriluRel)
+    .set_support_level(3)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0b2e51e54471..e500f0902c83 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5242,23 +5242,7 @@ def verify_eyelike(indata, dynamic=False):
     "test_training_dropout_mask",
     "test_training_dropout_zero_ratio",
     "test_training_dropout_zero_ratio_mask",
-    "test_tril",
-    "test_tril_pos",
-    "test_tril_square",
-    "test_tril_square_neg",
-    "test_tril_neg",
-    "test_tril_one_row_neg",
-    "test_tril_out_neg",
-    "test_tril_out_pos",
     "test_tril_zero",
-    "test_triu",
-    "test_triu_one_row",
-    "test_triu_out_neg_out",
-    "test_triu_out_pos",
-    "test_triu_neg",
-    "test_triu_pos",
-    "test_triu_square",
-    "test_triu_square_neg",
     "test_triu_zero",
     "test_unique_sorted_with_axis",
     "test_unique_sorted_with_axis_3d",
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index f91a027de4bc..b641ba1fdb13 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -2207,5 +2207,34 @@ def test_stft(
     )
 
 
+def test_trilu(target="llvm", dev=tvm.cpu()):
+    def verify_trilu(data_shape, upper=True, k=0):
+        data = relay.var("data", relay.TensorType(data_shape, "float32"))
+        y = relay.trilu(data, k, upper)
+        mod = tvm.ir.IRModule.from_expr(y)
+
+        data_np = np.random.normal(size=data_shape).astype("float32")
+        tvm_res = (
+            relay.create_executor("graph", mod=mod, device=dev, target=target)
+            .evaluate()(data_np)
+            .numpy()
+        )
+        if upper:
+            np_res = np.triu(data_np, k)
+        else:
+            np_res = np.tril(data_np, k)
+        tvm.testing.assert_allclose(tvm_res, np_res)
+
+    # Test upper and lower triangle
+    verify_trilu((3, 3), True, 0)
+    verify_trilu((3, 3), False, 0)
+    # Test larger matrices with offset.
+    verify_trilu((6, 6), True, 1)
+    verify_trilu((6, 6), False, 2)
+    verify_trilu((6, 6), False, -2)
+    # Test batch size
+    verify_trilu((8, 6, 6), False, -2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index c4493477abb1..da4a8c033afc 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -814,7 +814,7 @@ def check_device(target, dev):
 
 def verify_trilu(input_shape, upper, k=0):
     x = te.placeholder(shape=input_shape, name="x", dtype="float32")
-    trilu_result = topi.transform.trilu(x, upper, k)
+    trilu_result = topi.transform.trilu(x, k, upper)
 
     def check_device(target, dev):
         print("Running on target: %s" % target)
@@ -885,10 +885,10 @@ def test_reinterpret():
         (1000,), "int16", "uint16", lambda shape: np.random.randint(-1000, 1000, size=shape)
     )
     verify_reinterpret(
-        (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape)
+        (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2**32 - 1, size=shape)
     )
     verify_reinterpret(
-        (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape)
+        (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2**32 - 1, size=shape)
     )
 

From 5f7f7448a1f4027817388695c10315095068a59c Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Tue, 19 Jul 2022 09:11:25 -0700
Subject: [PATCH 3/7] Fix test type.

---
 tests/python/topi/python/test_topi_transform.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index da4a8c033afc..c3155c948a8d 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -814,7 +814,8 @@ def check_device(target, dev):
 
 def verify_trilu(input_shape, upper, k=0):
     x = te.placeholder(shape=input_shape, name="x", dtype="float32")
-    trilu_result = topi.transform.trilu(x, k, upper)
+    k_tir = tvm.tir.const(k, dtype="int32")
+    trilu_result = topi.transform.trilu(x, k_tir, upper)
 
     def check_device(target, dev):
         print("Running on target: %s" % target)

From 0721e22332713529c4affe43e7ed39ad927d35b1 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Tue, 26 Jul 2022 16:27:17 -0700
Subject: [PATCH 4/7] Add tril zero tests.

---
 tests/python/frontend/onnx/test_forward.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index e500f0902c83..876e44a78ded 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5242,8 +5242,6 @@ def verify_eyelike(indata, dynamic=False):
     "test_training_dropout_mask",
     "test_training_dropout_zero_ratio",
     "test_training_dropout_zero_ratio_mask",
-    "test_tril_zero",
-    "test_triu_zero",
     "test_unique_sorted_with_axis",
     "test_unique_sorted_with_axis_3d",
     "test_unique_sorted_with_negative_axis",

From aa468519b38968c276ddba20ce98b9dd232db9ab Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Tue, 26 Jul 2022 17:18:47 -0700
Subject: [PATCH 5/7] Add pytorch trilu integration.
---
 python/tvm/relay/frontend/pytorch.py          |  8 ++++++++
 tests/python/frontend/pytorch/test_forward.py | 10 ++++++++++
 2 files changed, 18 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1bd3232871ee..63df3fd690d9 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3405,6 +3405,12 @@ def grid_sampler(self, inputs, input_types):
             inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners
         )
 
+    def trilu(self, inputs, input_types, mode):
+        data = inputs[0]
+        k = inputs[1] if inputs[1] else 0
+        upper = True if mode == "triu" else False
+        return _op.trilu(data, k, upper)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -3661,6 +3667,8 @@ def create_convert_map(self):
             "aten::dot": self.dot,
             "aten::mv": self.mv,
             "aten::grid_sampler": self.grid_sampler,
+            "aten::triu": functools.partial(self.trilu, mode="triu"),
+            "aten::tril": functools.partial(self.trilu, mode="tril"),
             "aten::__ior__": self.make_elemwise("bitwise_or"),
             "aten::__iand__": self.make_elemwise("bitwise_and"),
             "aten::__ixor__": self.make_elemwise("bitwise_xor"),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index f52c7168b341..1d07c780b76b 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4616,5 +4616,15 @@ def test_fn(x, y, w):
     verify_model(test_fn, [x, y, w[0]])
 
 
+def test_trilu():
+    def _test_trilu(op, diagonal):
+        return lambda inp: op(inp, diagonal)
+
+    for op in [torch.triu, torch.tril]:
+        verify_model(_test_trilu(op, 0), [torch.rand(size=[3, 3])])
+        verify_model(_test_trilu(op, 1), [torch.rand(size=[6, 6])])
+        verify_model(_test_trilu(op, -2), [torch.rand(size=[6, 6])])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

From 0585271f35a79aec8f0d8232e47706064dce97eb Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Tue, 26 Jul 2022 17:35:07 -0700
Subject: [PATCH 6/7] Clean up torch integration.

---
 python/tvm/relay/frontend/pytorch.py | 31 ++--------------------------
 1 file changed, 2 insertions(+), 29 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 63df3fd690d9..74ea249a4785 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -318,31 +318,6 @@ def square(self, inputs, input_types):
         (dtype,) = input_types
         return _op.power(inputs[0], _expr.const(2, dtype))
 
-    def tril(self, inputs, input_types):
-        data = inputs[0]
-        if len(inputs) == 2:
-            k_value = inputs[1]
-        else:
-            k_value = 0
-        input_shape = self.infer_shape(data)
-        k1, k2 = input_shape[-2:]
-        k1 = k_value + 1
-        diag_input = _op.zeros(input_shape, dtype=input_types[0])
-        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
-
-    def triu(self, inputs, input_types):
-        data = inputs[0]
-        if len(inputs) == 2:
-            k_value = inputs[1]
-        else:
-            k_value = 0
-        input_shape = self.infer_shape(data)
-        k1, k2 = input_shape[-2:]
-        k1 = (k1 * -1) - 1
-        k2 = k_value - 1
-        diag_input = _op.zeros(input_shape, dtype=input_types[0])
-        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
-
     def lerp(self, inputs, input_types):
         if len(inputs) != 3:
             msg = "Wrong number of arguments (%d) to parse." % (len(inputs))
@@ -3573,8 +3548,8 @@ def create_convert_map(self):
             "aten::sqrt": self.make_unary("sqrt"),
             "aten::rsqrt": self.make_unary("rsqrt"),
             "aten::square": self.square,
-            "aten::tril": self.tril,
-            "aten::triu": self.triu,
+            "aten::tril": functools.partial(self.trilu, mode="tril"),
+            "aten::triu": functools.partial(self.trilu, mode="triu"),
             "aten::ceil": self.make_unary("ceil"),
             "aten::floor": self.make_unary("floor"),
             "aten::round": self.make_unary("round"),
@@ -3667,8 +3642,6 @@ def create_convert_map(self):
             "aten::dot": self.dot,
             "aten::mv": self.mv,
             "aten::grid_sampler": self.grid_sampler,
-            "aten::triu": functools.partial(self.trilu, mode="triu"),
-            "aten::tril": functools.partial(self.trilu, mode="tril"),
             "aten::__ior__": self.make_elemwise("bitwise_or"),
             "aten::__iand__": self.make_elemwise("bitwise_and"),
             "aten::__ixor__": self.make_elemwise("bitwise_xor"),

From 47366cf40174d74110a2170f63bb596ed8e2b47d Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Wed, 27 Jul 2022 08:43:17 -0700
Subject: [PATCH 7/7] Re-added skip for zero tests.

---
 tests/python/frontend/onnx/test_forward.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 876e44a78ded..e500f0902c83 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5242,6 +5242,8 @@ def verify_eyelike(indata, dynamic=False):
     "test_training_dropout_mask",
     "test_training_dropout_zero_ratio",
     "test_training_dropout_zero_ratio_mask",
+    "test_tril_zero",
+    "test_triu_zero",
     "test_unique_sorted_with_axis",
     "test_unique_sorted_with_axis_3d",
     "test_unique_sorted_with_negative_axis",
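
Note (not part of the patches above): the Select predicate in _apply_trilu keeps
x[..., row, col] when row <= col - k for upper=True (np.triu semantics) and when
row >= col - k for upper=False (np.tril semantics). A minimal NumPy sketch of that
rule, useful as a cross-check of the GE/LE choice in the topi compute; the helper
name reference_trilu is illustrative and does not appear in the patches:

    import numpy as np

    def reference_trilu(x, k, upper):
        # Same predicate as the topi compute, evaluated on index grids:
        # keep x[row, col] iff row <= col - k (upper) or row >= col - k (lower).
        rows, cols = np.indices(x.shape[-2:])
        keep = rows <= cols - k if upper else rows >= cols - k
        return np.where(keep, x, 0)

    x = np.arange(9, dtype="float32").reshape(3, 3)
    for k in (-2, -1, 0, 1, 2):
        assert np.array_equal(reference_trilu(x, k, True), np.triu(x, k))
        assert np.array_equal(reference_trilu(x, k, False), np.tril(x, k))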
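
For end-to-end use, a minimal sketch assuming a TVM build with this series applied;
it exercises only APIs the patches add and mirrors the new test in test_op_level3.py:

    import numpy as np
    import tvm
    from tvm import relay

    data = relay.var("data", relay.TensorType((3, 3), "float32"))
    # upper=True with k=1 also zeroes the main diagonal, like np.triu(x, 1).
    mod = tvm.ir.IRModule.from_expr(relay.trilu(data, k=1, upper=True))

    data_np = np.random.normal(size=(3, 3)).astype("float32")
    out = (
        relay.create_executor("graph", mod=mod, target="llvm", device=tvm.cpu())
        .evaluate()(data_np)
        .numpy()
    )
    np.testing.assert_allclose(out, np.triu(data_np, 1))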