Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,13 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
Integer axis;
DataType dtype;
Integer exclusive;
TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue<Integer>());
TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());
TVM_ATTR_FIELD(exclusive)
.describe("The first element is not included")
.set_default(NullValue<Integer>());
}
};

Expand Down
25 changes: 24 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .. import ty as _ty

from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value
from .common import infer_type, get_name


Expand Down Expand Up @@ -1075,6 +1075,28 @@ def _impl_v1(cls, inputs, attr, params):
return _op.shape_of(inputs[0], "int64")


class CumSum(OnnxOpConverter):
"""Operator converter for CumSum."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
data = inputs[0]
dim = inputs[1]

if dim is not None:
dim = int(infer_value(dim, params).asnumpy())

exclusive = attr.get("exclusive", 0)
reverse = attr.get("reverse", 0)

if reverse != 0:
out = _op.reverse(data, axis=dim)
out = _op.cumsum(out, axis=dim, exclusive=exclusive)
return _op.reverse(out, axis=dim)

return _op.cumsum(data, axis=dim, exclusive=exclusive)


class Cast(OnnxOpConverter):
"""Operator converter for Cast."""

Expand Down Expand Up @@ -2736,6 +2758,7 @@ def _get_convert_map(opset):
"Resize": Resize.get_converter(opset),
"NonZero": NonZero.get_converter(opset),
"Range": Range.get_converter(opset),
"CumSum": CumSum.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def compute_scatter_nd(attrs, inputs, output_type):
@_reg.register_compute("cumsum")
def compute_cumsum(attrs, inputs, output_type):
"""Compute definition of cumsum"""
return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype)]
return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]


_reg.register_strategy("cumsum", strategy.cumsum_strategy)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,7 @@ def wrap_compute_cumsum(topi_compute):
"""Wrap cumsum topi compute"""

def _compute_cumsum(attrs, inputs, _):
return [topi_compute(inputs[0], attrs.axis, attrs.dtype)]
return [topi_compute(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]

return _compute_cumsum

Expand Down
10 changes: 8 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,7 @@ def adv_index(inputs):
return _make.adv_index(Tuple(inputs))


def cumsum(data, axis=None, dtype=None):
def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
a given axis.

Expand All @@ -1339,6 +1339,12 @@ def cumsum(data, axis=None, dtype=None):
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.

exclusive : int, optional
If set to 1 will return exclusive sum in which the first element is not
included. In other terms, if set to 1, the j-th output element would be
the sum of the first (j-1) elements. Otherwise, it would be the sum of
the first j elements.

Returns
-------
result : relay.Expr
Expand Down Expand Up @@ -1368,4 +1374,4 @@ def cumsum(data, axis=None, dtype=None):
cumsum(a, dtype=int32) # dtype should be provided to get the expected results
-> [1, 1, 2, 2, 3, 4, 4]
"""
return _make.cumsum(data, axis, dtype)
return _make.cumsum(data, axis, dtype, exclusive)
10 changes: 9 additions & 1 deletion python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def traverse(op):
return s


def cumsum(data, axis=None, dtype=None):
def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.

Parameters
Expand All @@ -504,6 +504,12 @@ def cumsum(data, axis=None, dtype=None):
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.

exclusive : int, optional
If set to 1 will return exclusive sum in which the first element is not
included. In other terms, if set to 1, the j-th output element would be
the sum of the first (j-1) elements. Otherwise, it would be the sum of
the first j elements.

Returns
-------
result : tvm.te.Tensor
Expand All @@ -514,4 +520,6 @@ def cumsum(data, axis=None, dtype=None):
axis = 0
data = reshape(data, (prod(data.shape),))
axis = get_const_int(axis)
if exclusive is not None and exclusive != 0:
return exclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add)
return inclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add)
21 changes: 18 additions & 3 deletions python/tvm/topi/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .math import cast


def cumsum(data, axis=None, dtype=None):
def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.

Parameters
Expand All @@ -38,6 +38,12 @@ def cumsum(data, axis=None, dtype=None):
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.

exclusive : int, optional
If set to 1 will return exclusive sum in which the first element is not
included. In other terms, if set to 1, the j-th output element would be
the sum of the first (j-1) elements. Otherwise, it would be the sum of
the first j elements.

Returns
-------
result : tvm.te.Tensor
Expand Down Expand Up @@ -75,6 +81,9 @@ def maybe_cast(x):
elif i > axis:
axis_mul_after *= value

if exclusive is None:
exclusive = 0

def gen_ir(data_buf, out_buf):
ib = ir_builder.create()
data_buf = ib.buffer_ptr(data_buf)
Expand All @@ -84,12 +93,18 @@ def gen_ir(data_buf, out_buf):
i = fused // axis_mul_after
j = fused % axis_mul_after
base_idx = i * cumsum_axis_len * axis_mul_after + j
out_buf[base_idx] = maybe_cast(data_buf[base_idx])
if exclusive == 0:
out_buf[base_idx] = maybe_cast(data_buf[base_idx])
else:
out_buf[base_idx] = cast(0, dtype)
with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k:
k = _k + 1
cur_idx = base_idx + k * axis_mul_after
prev_idx = base_idx + (k - 1) * axis_mul_after
out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx])
if exclusive == 0:
out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx])
else:
out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[prev_idx])

return ib.get()

Expand Down
3 changes: 2 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3705,10 +3705,11 @@ bool CumsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}

Expr MakeCumsum(Expr data, Integer axis, DataType dtype) {
Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Integer exclusive) {
auto attrs = make_object<CumsumAttrs>();
attrs->dtype = dtype;
attrs->axis = axis;
attrs->exclusive = exclusive;
static const Op& op = Op::Get("cumsum");
return Call(op, {data}, Attrs(attrs), {});
}
Expand Down
77 changes: 77 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3964,6 +3964,82 @@ def verify_softplus(indata):
verify_softplus(input_data)


def test_cumsum():
def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
cumsum_node = onnx.helper.make_node(
"CumSum",
inputs=["X", "axis"],
outputs=["Y"],
)
if exclusive != 0:
exclusive_attr = helper.make_attribute("exclusive", exclusive)
cumsum_node.attribute.append(exclusive_attr)
if reverse != 0:
reverse_attr = helper.make_attribute("reverse", reverse)
cumsum_node.attribute.append(reverse_attr)
nodes = [
make_constant_node("axis", onnx.TensorProto.INT32, [1], [axis]),
cumsum_node,
]
if type == "float32":
tensor_type = TensorProto.FLOAT
else:
tensor_type = TensorProto.INT32
type = "int32"

graph = helper.make_graph(
nodes,
"cumsum_test",
inputs=[
helper.make_tensor_value_info("X", tensor_type, list(indata.shape)),
],
outputs=[helper.make_tensor_value_info("Y", tensor_type, list(indata.shape))],
)

model = helper.make_model(graph, producer_name="cumsum_test")

verify_with_ort_with_inputs(model, [indata], dtype=type, use_vm=True, opset=11)

data = (
np.array(
[
1.0,
2.0,
3.0,
4.0,
5.0,
6.0,
7.0,
8.0,
9.0,
10.0,
11.0,
12.0,
]
)
.astype(np.float32)
.reshape((3, 4))
)

verify_cumsum(data, 0)
verify_cumsum(data, 1)
verify_cumsum(data, 0, 1, 0)
verify_cumsum(data, 1, 1, 0)
verify_cumsum(data, 0, 0, 1)
verify_cumsum(data, 1, 0, 1)
verify_cumsum(data, 1, 1, 1)
data = np.random.randn(1, 32, 32, 3).astype("float32")
verify_cumsum(data, 1)
data = np.random.randn(1, 32, 32, 3).astype("int32")
verify_cumsum(data, 0, type="int32")
verify_cumsum(data, 1, type="int32")
verify_cumsum(data, 0, 1, 0, type="int32")
verify_cumsum(data, 1, 1, 0, type="int32")
verify_cumsum(data, 0, 0, 1, type="int32")
verify_cumsum(data, 1, 0, 1, type="int32")
verify_cumsum(data, 1, 1, 1, type="int32")


if __name__ == "__main__":
test_flatten()
test_reshape()
Expand Down Expand Up @@ -4040,3 +4116,4 @@ def verify_softplus(indata):
test_size()
test_maxunpool()
test_softplus()
test_cumsum()