Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,14 +479,10 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {

/*! \brief Attributes used in matrix_set_diag operator */
struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
int k1;
int k2;
bool super_diag_right_align;
bool sub_diag_right_align;

TVM_DECLARE_ATTRS(MatrixSetDiagAttrs, "relay.attrs.MatrixSetDiagAttrs") {
TVM_ATTR_FIELD(k1).set_default(0).describe("Lower limit (included) of the range of diagonals.");
TVM_ATTR_FIELD(k2).set_default(0).describe("Upper limit (included) of the range of diagonals.");
TVM_ATTR_FIELD(super_diag_right_align)
.set_default(true)
.describe("Bool, true iff super-diagonal is right aligned (left-padded).");
Expand Down
21 changes: 10 additions & 11 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1851,14 +1851,13 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr
* \param tag output tensor tag.
* \return new tensor with given diagonal values.
*/
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
bool super_diag_right_align, bool sub_diag_right_align,
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, const Tensor& k1,
const Tensor& k2, bool super_diag_right_align,
bool sub_diag_right_align,
const std::string name = "T_matrix_set_diag",
const std::string tag = kInjective) {
size_t ndim = input->shape.size() - 1;

bool only_one_diagonal = k1 == k2;

return compute(
input->shape,
[&](const Array<Var>& iter_vars) {
Expand All @@ -1868,12 +1867,10 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
for (size_t i = 0; i < ndim - 1; i++) {
diagonal_indices.push_back(iter_vars[i]);
}
if (only_one_diagonal) {
k = k1;
} else {
auto multi_diagonals = [&]() {
// Determining which diagonal/sub-diagonal/super-diagonal it is
k = iter_vars[ndim] - iter_vars[ndim - 1];
diagonal_indices.push_back(k2 - k);
diagonal_indices.push_back(k2(0) - k);

// Calculating the offset in diagonal tensor for this diagonal
auto get_offset = [&](PrimExpr M, PrimExpr N) {
Expand All @@ -1886,13 +1883,15 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
: 0,
sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
: 0);
}
return k;
};
k = if_then_else(k1(0) == k2(0), k1(0), multi_diagonals());
diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
offset);
return diagonal(diagonal_indices);
};
return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1(0),
if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2(0),
get_diag(), input(iter_vars)),
input(iter_vars));
},
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4637,6 +4637,37 @@ def _impl_v1(cls, inputs, attr, params):
return _expr.TupleWrapper(_expr.Tuple(result), len(result))


class Trilu(OnnxOpConverter):
"""Operator converter for Trilu"""

@classmethod
def _impl_v14(cls, inputs, attr, params):
upper = attr.get("upper", 1)
input_shape = shape_of(inputs[0])
input_dims = infer_shape(input_shape)[0]
data_type = infer_type(inputs[0]).checked_type.dtype
k_tensor = relay.const(np.asarray(0), dtype=np.int64)
if len(inputs) == 2:
k_tensor = inputs[1]

diag_input = relay.zeros(fold_constant(input_shape), dtype=data_type)
k1, k2 = None, None
if upper == 0:
k1 = relay.add(k_tensor, relay.const(1, dtype="int64"))
k1 = relay.expand_dims(k1, axis=0)
k2 = relay.take(input_shape, relay.const(input_dims - 1, dtype="int32"))
k2 = relay.expand_dims(k2, axis=0)
else:
k1 = relay.take(input_shape, relay.const(input_dims - 2, dtype="int32"))
k1 = relay.multiply(k1, relay.const(-1, dtype="int64"))
k1 = relay.subtract(k1, relay.const(1, dtype="int64"))
k1 = relay.expand_dims(k1, axis=0)
k2 = relay.subtract(k_tensor, relay.const(1, dtype="int64"))
k2 = relay.expand_dims(k2, axis=0)

return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2))


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -4810,6 +4841,7 @@ def _get_convert_map(opset):
"CumSum": CumSum.get_converter(opset),
"Unique": Unique.get_converter(opset),
"Einsum": Einsum.get_converter(opset),
"Trilu": Trilu.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
Expand Down
11 changes: 8 additions & 3 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3281,6 +3281,11 @@ def convert_matrix_set_diag(self, op):

input_expr = self.get_tensor_expr(input_tensors[0])
diagonal_expr = self.get_tensor_expr(input_tensors[1])
diag_shape = to_int_list(self.get_tensor_shape(input_tensors[1]))
input_shape = to_int_list(self.get_tensor_shape(input_tensors[0]))
if len(diag_shape) == len(input_shape) - 1:
diag_shape = np.insert(diag_shape, len(diag_shape) - 1, 1)
diagonal_expr = _op.reshape(diagonal_expr, diag_shape)

out = _op.matrix_set_diag(input_expr, diagonal_expr)
return out
Expand All @@ -3301,13 +3306,13 @@ def convert_matrix_diag(self, op):
scale and zero points to be equal"

shape = to_int_list(self.get_tensor_shape(diagonal))
shape = np.append(shape, shape[-1])
diag_shape = np.insert(shape, len(shape) - 1, 1).astype(np.int32)
dtype = self.get_tensor_type_str(diagonal.tensor.Type())

shape = np.append(shape, shape[-1]).astype(np.int32)
input_expr = _op.zeros(tuple(shape), dtype)
diagonal_expr = self.get_tensor_expr(diagonal)

out = _op.matrix_set_diag(input_expr, diagonal_expr)
out = _op.matrix_set_diag(input_expr, _op.reshape(diagonal_expr, diag_shape))
return out

def convert_densify(self, op):
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# pylint: disable=import-outside-toplevel
"""Transform operators."""

import numpy as np
from ...tir import expr as _expr
from ..expr import Constant, Expr, Tuple, TupleWrapper, const
from . import _make
Expand Down Expand Up @@ -1409,6 +1410,11 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
k_one = k
k_two = k

if not isinstance(k_one, Expr):
k_one = const(np.asarray([k_one], dtype=np.int64))
if not isinstance(k_two, Expr):
k_two = const(np.asarray([k_two], dtype=np.int64))

super_diag_right_align = align[:5] == "RIGHT"
sub_diag_right_align = align[-5:] == "RIGHT"

Expand Down
40 changes: 14 additions & 26 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3811,60 +3811,46 @@ TVM_REGISTER_NODE_TYPE(MatrixSetDiagAttrs);

bool MatrixSetDiagRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [input, diagonal, result]
ICHECK_EQ(types.size(), 3);
// `types` contains: [input, diagonal, k1, k2, result]
ICHECK_EQ(types.size(), 5);

const auto* input = types[0].as<TensorTypeNode>();
ICHECK(input);

const auto* diagonal = types[1].as<TensorTypeNode>();
ICHECK(diagonal);

const auto param = attrs.as<MatrixSetDiagAttrs>();
ICHECK_GE(param->k2, param->k1);

int d_ndims = diagonal->shape.size();
int i_ndims = input->shape.size();
const auto* k1 = types[2].as<TensorTypeNode>();
ICHECK(k1);

reporter->Assert(input->shape[i_ndims - 2] > -param->k1);
reporter->Assert(input->shape[i_ndims - 1] > param->k2);
const auto* k2 = types[3].as<TensorTypeNode>();
ICHECK(k2);

int d_ndims = diagonal->shape.size();
for (int i = 0; i < d_ndims - 2; i++) {
reporter->AssertEQ(input->shape[i], diagonal->shape[i]);
}
if (param->k1 != param->k2) {
reporter->AssertEQ(diagonal->shape[d_ndims - 2], param->k2 - param->k1 + 1);
} else if (d_ndims >= 2) {
reporter->AssertEQ(input->shape[d_ndims - 2], diagonal->shape[d_ndims - 2]);
}
auto max_diag_len = if_then_else(input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0) <=
input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0),
input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0),
input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0));
reporter->AssertEQ(diagonal->shape[d_ndims - 1], max_diag_len);

reporter->Assign(types[2], TensorType(input->shape, input->dtype));
reporter->Assign(types[4], TensorType(input->shape, input->dtype));
return true;
}

Array<te::Tensor> MatrixSetDiagCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<MatrixSetDiagAttrs>();
ICHECK(param != nullptr);
return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1], param->k1, param->k2,
return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1], inputs[2], inputs[3],
param->super_diag_right_align,
param->sub_diag_right_align)};
}

Expr MakeMatrixSetDiag(Expr input, Expr diagonal, int k1, int k2, bool super_diag_right_align,
Expr MakeMatrixSetDiag(Expr input, Expr diagonal, Expr k1, Expr k2, bool super_diag_right_align,
bool sub_diag_right_align) {
auto attrs = make_object<MatrixSetDiagAttrs>();
attrs->k1 = k1;
attrs->k2 = k2;
attrs->super_diag_right_align = super_diag_right_align;
attrs->sub_diag_right_align = sub_diag_right_align;
static const Op& op = Op::Get("matrix_set_diag");
return Call(op, {input, diagonal}, Attrs(attrs), {});
return Call(op, {input, diagonal, k1, k2}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag);
Expand All @@ -3880,9 +3866,11 @@ RELAY_REGISTER_OP("matrix_set_diag")
**sub_diag_right_align** Bool, true iff sub-diagonal is right aligned (left-padded).
)code" TVM_ADD_FILELINE)
.set_attrs_type<MatrixSetDiagAttrs>()
.set_num_inputs(2)
.set_num_inputs(4)
.add_argument("input", "Tensor", "Input Tensor.")
.add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.")
.add_argument("k1", "Tensor", "Lower limit (included) of the range of diagonals.")
.add_argument("k2", "Tensor", "Upper limit (included) of the range of diagonals.")
.set_support_level(10)
.add_type_rel("MatrixSetDiag", MatrixSetDiagRel)
.set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
Expand Down
5 changes: 2 additions & 3 deletions src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,10 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) {
});

TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValue* rv) {
int k1 = args[2];
int k2 = args[3];
bool super_diag_right_align = args[4];
bool sub_diag_right_align = args[5];
*rv = matrix_set_diag(args[0], args[1], k1, k2, super_diag_right_align, sub_diag_right_align);
*rv = matrix_set_diag(args[0], args[1], args[2], args[3], super_diag_right_align,
sub_diag_right_align);
});

TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) {
Expand Down
18 changes: 0 additions & 18 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5114,24 +5114,6 @@ def verify_eyelike(indata):
"test_training_dropout_mask",
"test_training_dropout_zero_ratio",
"test_training_dropout_zero_ratio_mask",
"test_tril",
"test_tril_pos",
"test_tril_square",
"test_tril_square_neg",
"test_tril_neg",
"test_tril_one_row_neg",
"test_tril_out_neg",
"test_tril_out_pos",
"test_tril_zero",
"test_triu",
"test_triu_one_row",
"test_triu_out_neg_out",
"test_triu_out_pos",
"test_triu_neg",
"test_triu_pos",
"test_triu_square",
"test_triu_square_neg",
"test_triu_zero",
# These unsqueeze tests work, but take 2+ hrs to run
"test_unsqueeze_three_axes",
"test_unsqueeze_two_axes",
Expand Down
8 changes: 7 additions & 1 deletion tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,13 @@ def test_matrix_set_diag():
def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
input = relay.var("input", relay.TensorType(input_shape, dtype))
diagonal = relay.var("diagonal", relay.TensorType(diagonal_shape, dtype))
out = relay.matrix_set_diag(input, diagonal, k, align)
out = None
if len(diagonal_shape) == len(input_shape) - 1:
new_shape = list(diagonal_shape)
new_shape.insert(-1, 1)
out = relay.matrix_set_diag(input, relay.reshape(diagonal, new_shape), k, align)
else:
out = relay.matrix_set_diag(input, diagonal, k, align)

in_type = run_infer_type(input)
out_type = run_infer_type(out)
Expand Down
56 changes: 50 additions & 6 deletions tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,24 +750,68 @@ def check_device(target, dev):


def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
# input matrix that contains diagonals to be replaced
input = te.placeholder(shape=input_shape, name="input", dtype=dtype)
# diagonal values to be placed as new diagonal values of input matrix
diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype)
matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align)
# diagonals offsets
# k1 and k2 define the lower and upper limits of diagonals to be set
# where k*=0 means main diagonal, k*< 0 sub-diagonal, and k*> 0 super-diagonal
# when k is not an tuple or list, k1 will be equal to k2, meaning that only one diagonal will be replaced.
k1 = te.placeholder(shape=(1,), name="k1", dtype="int64")
# k2 defines the upper limit diagonal to be set
k2 = te.placeholder(shape=(1,), name="k2", dtype="int64")
matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, (k1, k2), align)

# k can be an integer or a pair of integers representing the lower and upper limits of a matrix band;
k_one, k_two = None, None
if isinstance(k, (tuple, list)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some comments to this test? It's a little hard to follow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I will add some comments and update the PR.

k_one = k[0]
if len(k) >= 2:
k_two = k[1]
else:
k_two = k[0]
else:
k_one = k
k_two = k

# Generate random data for input matrix
input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype)
# Generate random data for diagonal (single or multiple diagionals)
diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype)
# Run numpy test for matrix_set_diag with random data
# output will be saved to compare with TOPI version of matrix_set_diag
out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align)

def check_device(target, dev):
dev = tvm.device(target, 0)
print("Running on target: %s" % target)
with tvm.target.Target(target):
s = tvm.topi.testing.get_injective_schedule(target)(matrix_set_diag_result)
fn = tvm.build(s, [input, diagonal, matrix_set_diag_result], target, name="matrix_set_diag")
input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype)
diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype)
out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align)
fn = tvm.build(
s, [input, diagonal, k1, k2, matrix_set_diag_result], target, name="matrix_set_diag"
)

# Convert numpy input data to TVM ND array
input_nd = tvm.nd.array(input_npy, dev)

# Convert numpy diagonal data to TVM ND array
diagonal_nd = tvm.nd.array(diagonal_npy, dev)

# Convert k1 and k2 to numpy array and then to TVM ND array
k1_nd = tvm.nd.array(np.asarray([k_one]), dev)
k2_nd = tvm.nd.array(np.asarray([k_two]), dev)

# Convert k1 and k2 to numpy array and then to TVM ND array
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(matrix_set_diag_result.dtype), dev)
fn(input_nd, diagonal_nd, out_nd)

# Run TOPI test for matrix_set_diag with random data
fn(input_nd, diagonal_nd, k1_nd, k2_nd, out_nd)

# Convert TOPI output to numpy
out_topi = out_nd.numpy()

# Check if Numpy version matches TOPI one
tvm.testing.assert_allclose(out_topi, out_npy)

for target, dev in tvm.testing.enabled_targets():
Expand Down