diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 43166249638a..45a1caf2bd79 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -442,9 +442,13 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode { struct CumsumAttrs : public tvm::AttrsNode { Integer axis; DataType dtype; + Integer exclusive; TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") { TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue()); TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue()); + TVM_ATTR_FIELD(exclusive) + .describe("The first element is not included") + .set_default(NullValue()); } }; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c423598a2ee7..c9140d782a2d 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -34,7 +34,7 @@ from .. import ty as _ty from .common import AttrCvt, Renamer -from .common import get_relay_op, new_var, infer_shape, infer_channels +from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value from .common import infer_type, get_name @@ -1075,6 +1075,28 @@ def _impl_v1(cls, inputs, attr, params): return _op.shape_of(inputs[0], "int64") +class CumSum(OnnxOpConverter): + """Operator converter for CumSum.""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + data = inputs[0] + dim = inputs[1] + + if dim is not None: + dim = int(infer_value(dim, params).asnumpy()) + + exclusive = attr.get("exclusive", 0) + reverse = attr.get("reverse", 0) + + if reverse != 0: + out = _op.reverse(data, axis=dim) + out = _op.cumsum(out, axis=dim, exclusive=exclusive) + return _op.reverse(out, axis=dim) + + return _op.cumsum(data, axis=dim, exclusive=exclusive) + + class Cast(OnnxOpConverter): """Operator converter for Cast.""" @@ -2736,6 +2758,7 @@ def _get_convert_map(opset): "Resize": Resize.get_converter(opset), "NonZero": NonZero.get_converter(opset), "Range": Range.get_converter(opset), + "CumSum": CumSum.get_converter(opset), # defs/control_flow "Loop": Loop.get_converter(opset), "If": If.get_converter(opset), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index fd07c98ddc1f..ba2416ff8950 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -116,7 +116,7 @@ def compute_scatter_nd(attrs, inputs, output_type): @_reg.register_compute("cumsum") def compute_cumsum(attrs, inputs, output_type): """Compute definition of cumsum""" - return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype)] + return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)] _reg.register_strategy("cumsum", strategy.cumsum_strategy) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 3ad75faf4bc1..af1d2552fab7 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1367,7 +1367,7 @@ def wrap_compute_cumsum(topi_compute): """Wrap cumsum topi compute""" def _compute_cumsum(attrs, inputs, _): - return [topi_compute(inputs[0], attrs.axis, attrs.dtype)] + return [topi_compute(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)] return _compute_cumsum diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 6785ff248612..e9d081eb5fb6 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1322,7 +1322,7 @@ def adv_index(inputs): return _make.adv_index(Tuple(inputs)) -def cumsum(data, axis=None, dtype=None): +def cumsum(data, axis=None, dtype=None, exclusive=None): """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along a given axis. @@ -1339,6 +1339,12 @@ def cumsum(data, axis=None, dtype=None): Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of data. + exclusive : int, optional + If set to 1 will return exclusive sum in which the first element is not + included. In other terms, if set to 1, the j-th output element would be + the sum of the first (j-1) elements. Otherwise, it would be the sum of + the first j elements. + Returns ------- result : relay.Expr @@ -1368,4 +1374,4 @@ def cumsum(data, axis=None, dtype=None): cumsum(a, dtype=int32) # dtype should be provided to get the expected results -> [1, 1, 2, 2, 3, 4, 4] """ - return _make.cumsum(data, axis, dtype) + return _make.cumsum(data, axis, dtype, exclusive) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 232d679840fd..0bdab100b429 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -488,7 +488,7 @@ def traverse(op): return s -def cumsum(data, axis=None, dtype=None): +def cumsum(data, axis=None, dtype=None, exclusive=None): """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis. Parameters @@ -504,6 +504,12 @@ def cumsum(data, axis=None, dtype=None): Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of data. + exclusive : int, optional + If set to 1 will return exclusive sum in which the first element is not + included. In other terms, if set to 1, the j-th output element would be + the sum of the first (j-1) elements. Otherwise, it would be the sum of + the first j elements. + Returns ------- result : tvm.te.Tensor @@ -514,4 +520,6 @@ def cumsum(data, axis=None, dtype=None): axis = 0 data = reshape(data, (prod(data.shape),)) axis = get_const_int(axis) + if exclusive is not None and exclusive != 0: + return exclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add) return inclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add) diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py index 855427b1c619..2013a352874d 100644 --- a/python/tvm/topi/cumsum.py +++ b/python/tvm/topi/cumsum.py @@ -22,7 +22,7 @@ from .math import cast -def cumsum(data, axis=None, dtype=None): +def cumsum(data, axis=None, dtype=None, exclusive=None): """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis. Parameters @@ -38,6 +38,12 @@ def cumsum(data, axis=None, dtype=None): Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of data. + exclusive : int, optional + If set to 1 will return exclusive sum in which the first element is not + included. In other terms, if set to 1, the j-th output element would be + the sum of the first (j-1) elements. Otherwise, it would be the sum of + the first j elements. + Returns ------- result : tvm.te.Tensor @@ -75,6 +81,9 @@ def maybe_cast(x): elif i > axis: axis_mul_after *= value + if exclusive is None: + exclusive = 0 + def gen_ir(data_buf, out_buf): ib = ir_builder.create() data_buf = ib.buffer_ptr(data_buf) @@ -84,12 +93,18 @@ def gen_ir(data_buf, out_buf): i = fused // axis_mul_after j = fused % axis_mul_after base_idx = i * cumsum_axis_len * axis_mul_after + j - out_buf[base_idx] = maybe_cast(data_buf[base_idx]) + if exclusive == 0: + out_buf[base_idx] = maybe_cast(data_buf[base_idx]) + else: + out_buf[base_idx] = cast(0, dtype) with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k: k = _k + 1 cur_idx = base_idx + k * axis_mul_after prev_idx = base_idx + (k - 1) * axis_mul_after - out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx]) + if exclusive == 0: + out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx]) + else: + out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[prev_idx]) return ib.get() diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index d44bfe6959ca..5e39b409615d 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3705,10 +3705,11 @@ bool CumsumRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -Expr MakeCumsum(Expr data, Integer axis, DataType dtype) { +Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Integer exclusive) { auto attrs = make_object(); attrs->dtype = dtype; attrs->axis = axis; + attrs->exclusive = exclusive; static const Op& op = Op::Get("cumsum"); return Call(op, {data}, Attrs(attrs), {}); } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 515fc32ef88d..27b91dd38f8e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3964,6 +3964,82 @@ def verify_softplus(indata): verify_softplus(input_data) +def test_cumsum(): + def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): + cumsum_node = onnx.helper.make_node( + "CumSum", + inputs=["X", "axis"], + outputs=["Y"], + ) + if exclusive != 0: + exclusive_attr = helper.make_attribute("exclusive", exclusive) + cumsum_node.attribute.append(exclusive_attr) + if reverse != 0: + reverse_attr = helper.make_attribute("reverse", reverse) + cumsum_node.attribute.append(reverse_attr) + nodes = [ + make_constant_node("axis", onnx.TensorProto.INT32, [1], [axis]), + cumsum_node, + ] + if type == "float32": + tensor_type = TensorProto.FLOAT + else: + tensor_type = TensorProto.INT32 + type = "int32" + + graph = helper.make_graph( + nodes, + "cumsum_test", + inputs=[ + helper.make_tensor_value_info("X", tensor_type, list(indata.shape)), + ], + outputs=[helper.make_tensor_value_info("Y", tensor_type, list(indata.shape))], + ) + + model = helper.make_model(graph, producer_name="cumsum_test") + + verify_with_ort_with_inputs(model, [indata], dtype=type, use_vm=True, opset=11) + + data = ( + np.array( + [ + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 10.0, + 11.0, + 12.0, + ] + ) + .astype(np.float32) + .reshape((3, 4)) + ) + + verify_cumsum(data, 0) + verify_cumsum(data, 1) + verify_cumsum(data, 0, 1, 0) + verify_cumsum(data, 1, 1, 0) + verify_cumsum(data, 0, 0, 1) + verify_cumsum(data, 1, 0, 1) + verify_cumsum(data, 1, 1, 1) + data = np.random.randn(1, 32, 32, 3).astype("float32") + verify_cumsum(data, 1) + data = np.random.randn(1, 32, 32, 3).astype("int32") + verify_cumsum(data, 0, type="int32") + verify_cumsum(data, 1, type="int32") + verify_cumsum(data, 0, 1, 0, type="int32") + verify_cumsum(data, 1, 1, 0, type="int32") + verify_cumsum(data, 0, 0, 1, type="int32") + verify_cumsum(data, 1, 0, 1, type="int32") + verify_cumsum(data, 1, 1, 1, type="int32") + + if __name__ == "__main__": test_flatten() test_reshape() @@ -4040,3 +4116,4 @@ def verify_softplus(indata): test_size() test_maxunpool() test_softplus() + test_cumsum()