From 66ebfdb254c3e26591b1b262b5452b13b4c9f1aa Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 19 Jan 2021 18:22:09 +0900 Subject: [PATCH 01/27] Add cumsum relay/topi op --- include/tvm/relay/attrs/transform.h | 9 +++ python/tvm/relay/op/_transform.py | 13 +++- python/tvm/relay/op/strategy/generic.py | 12 +++ python/tvm/relay/op/transform.py | 7 ++ python/tvm/topi/__init__.py | 1 + python/tvm/topi/cumsum.py | 80 ++++++++++++++++++++ src/relay/op/tensor/transform.cc | 3 + tests/python/topi/python/test_topi_cumsum.py | 63 +++++++++++++++ 8 files changed, 187 insertions(+), 1 deletion(-) create mode 100644 python/tvm/topi/cumsum.py create mode 100644 tests/python/topi/python/test_topi_cumsum.py diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index efa44e026c51..0694816feba2 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -438,6 +438,15 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode { } }; // struct MatrixSetDiagAttrs +struct CumsumAttrs : public tvm::AttrsNode { + Integer axis; + DataType dtype; + TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue()); + TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue()); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 05ca6d2e4bb9..0288261f5b55 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -103,7 +103,7 @@ def compute_scatter_add(attrs, inputs, output_type): _reg.register_strategy("scatter_add", strategy.scatter_add_strategy) -# scatter +# scatter_nd @_reg.register_compute("scatter_nd") def compute_scatter_nd(attrs, inputs, output_type): """Compute definition of scatter_nd""" @@ -112,6 +112,17 @@ def compute_scatter_nd(attrs, inputs, output_type): _reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy) +# cumsum +@_reg.register_compute("cumsum") +def compute_cumsum(attrs, inputs, output_type): + """Compute definition of cumsum""" + return [topi.cumsum(inputs[0], inputs[1], attrs.out_shape)] + + +_reg.register_strategy("cumsum", strategy.cumsum_strategy) +_reg.register_pattern("cumsum", OpPattern.OPAQUE) +_reg.register_shape_func("cumsum", False, elemwise_shape_func) + ##################### # Shape functions # ##################### diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 8dd9dc5844dd..9314f588044e 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1361,3 +1361,15 @@ def threefry_split_strategy(attrs, inputs, out_type, target): name="threefry_split.generic", ) return strategy + + +@override_native_generic_func("cumsum_strategy") +def cumsum_strategy(attrs, inputs, out_type, target): + """cumsum generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + topi.cumsum, + wrap_topi_schedule(topi.generic.schedule_extern), + name="cumsum.generic", + ) + return strategy diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 7e7f9b299593..b75e0106f84a 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1320,3 +1320,10 @@ def adv_index(inputs): Output tensor. """ return _make.adv_index(Tuple(inputs)) + + +def cumsum(data, axis=None, dtype=None): + """ + TODO + """ + return _make.cumsum(data, axis, dtype) diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index cb94b5b86c9e..873901df62a5 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -40,6 +40,7 @@ from .scatter import * from .scatter_add import * from .argwhere import * +from .cumsum import * from . import generic from . import nn from . import x86 diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py new file mode 100644 index 000000000000..3870e7dadfa0 --- /dev/null +++ b/python/tvm/topi/cumsum.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Cumsum operator""" +from ..tir import decl_buffer, ir_builder +from ..te import extern +from .utils import prod +from .math import cast + + +def cumsum(data, axis=None, dtype=None): + if dtype is None: + dtype = data.dtype + + def maybe_cast(x): + if dtype != data.dtype: + return cast(x, dtype) + return x + + axis_mul_before = 1 + axis_mul_after = 1 + + if axis is None and axis != 0: + axis = 0 + cumsum_axis_len = prod(data.shape) + shape = (cumsum_axis_len,) + else: + shape = data.shape + cumsum_axis_len = shape[axis] + + if axis < 0: + axis = len(shape) + axis + + for i, value in enumerate(shape, 0): + if i < axis: + axis_mul_before *= value + elif i > axis: + axis_mul_after *= value + + def gen_ir(data_buf, out_buf): + ib = ir_builder.create() + data_buf = ib.buffer_ptr(data_buf) + out_buf = ib.buffer_ptr(out_buf) + + with ib.for_range(0, axis_mul_before) as i: + with ib.for_range(0, axis_mul_after) as j: + base_idx = i * cumsum_axis_len * axis_mul_after + j + out_buf[base_idx] = maybe_cast(data_buf[base_idx]) + with ib.for_range(0, cumsum_axis_len - 1) as _k: + k = _k + 1 + cur_idx = base_idx + k * axis_mul_after + prev_idx = base_idx + (k - 1) * axis_mul_after + out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx]) + + return ib.get() + + out_buf = decl_buffer(shape, dtype, "out_buf") + + return extern( + [shape], + [data], + lambda ins, outs: gen_ir(ins[0], outs[0]), + dtype=dtype, + out_buffers=[out_buf], + name="cumsum_generic", + tag="cumsum_generic", + ) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ecfde359d11d..ee9d867d6e85 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3674,5 +3674,8 @@ RELAY_REGISTER_OP("adv_index") .set_attr("TOpPattern", kInjective) .set_attr("FTVMCompute", AdvIndexCompute); + +TVM_REGISTER_NODE_TYPE(CumsumAttrs); + } // namespace relay } // namespace tvm diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py new file mode 100644 index 000000000000..30ca322012bb --- /dev/null +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +import tvm.testing +from tvm import topi +import tvm.topi.testing + + +@tvm.testing.parametrize_targets +def test_cumsum(ctx, target): + def check_cumsum(np_ref, data, axis=None, dtype=None): + implementations = { + "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern), + } + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) + tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule) + + data = np.array([2, 3, 0]) + check_cumsum(np.cumsum(data), data) + + data = np.random.randn(10, 10) + check_cumsum(np.cumsum(data), data) + check_cumsum(np.cumsum(data, axis=0), data, axis=0) + check_cumsum(np.cumsum(data, axis=1), data, axis=1) + + data = np.random.randn(10, 5, 10) + check_cumsum(np.cumsum(data), data) + check_cumsum(np.cumsum(data, axis=0), data, axis=0) + check_cumsum(np.cumsum(data, axis=1), data, axis=1) + check_cumsum(np.cumsum(data, axis=-1), data, axis=-1) + + data = np.random.rand(10) > 0.5 + data = data.astype(np.int32) + check_cumsum(np.cumsum(data, dtype=np.int32), data) + check_cumsum(np.cumsum(data), data, dtype="int64") + + data = np.random.randint(-100, 100, size=(100, 100)).astype(np.int32) + check_cumsum(np.cumsum(data, dtype=np.int32), data) + check_cumsum(np.cumsum(data), data, dtype="int64") + check_cumsum(np.cumsum(data, axis=0, dtype=np.int32), data, axis=0) + check_cumsum(np.cumsum(data, axis=1, dtype=np.int32), data, axis=1) + + data = np.random.randint(1 << 30, (1 << 31) - 1, size=(100)).astype(np.int32) + check_cumsum(np.cumsum(data), data, dtype="int64") + + +if __name__ == "__main__": + test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm")) From 2f7c418503a1c708e689ca665d7eaec83c018a23 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 19 Jan 2021 18:45:00 +0900 Subject: [PATCH 02/27] relay tests working --- python/tvm/relay/op/_transform.py | 3 +- python/tvm/relay/op/strategy/generic.py | 11 ++++- python/tvm/topi/cumsum.py | 7 ++- src/relay/op/tensor/transform.cc | 48 +++++++++++++++++++- tests/python/relay/test_op_level3.py | 35 ++++++++++++++ tests/python/topi/python/test_topi_cumsum.py | 38 ++++++++-------- 6 files changed, 118 insertions(+), 24 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 0288261f5b55..fd07c98ddc1f 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -116,11 +116,10 @@ def compute_scatter_nd(attrs, inputs, output_type): @_reg.register_compute("cumsum") def compute_cumsum(attrs, inputs, output_type): """Compute definition of cumsum""" - return [topi.cumsum(inputs[0], inputs[1], attrs.out_shape)] + return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype)] _reg.register_strategy("cumsum", strategy.cumsum_strategy) -_reg.register_pattern("cumsum", OpPattern.OPAQUE) _reg.register_shape_func("cumsum", False, elemwise_shape_func) ##################### diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 9314f588044e..3ad75faf4bc1 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1363,12 +1363,21 @@ def threefry_split_strategy(attrs, inputs, out_type, target): return strategy +def wrap_compute_cumsum(topi_compute): + """Wrap cumsum topi compute""" + + def _compute_cumsum(attrs, inputs, _): + return [topi_compute(inputs[0], attrs.axis, attrs.dtype)] + + return _compute_cumsum + + @override_native_generic_func("cumsum_strategy") def cumsum_strategy(attrs, inputs, out_type, target): """cumsum generic strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - topi.cumsum, + wrap_compute_cumsum(topi.cumsum), wrap_topi_schedule(topi.generic.schedule_extern), name="cumsum.generic", ) diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py index 3870e7dadfa0..0c6aec4fbc97 100644 --- a/python/tvm/topi/cumsum.py +++ b/python/tvm/topi/cumsum.py @@ -17,12 +17,12 @@ """Cumsum operator""" from ..tir import decl_buffer, ir_builder from ..te import extern -from .utils import prod +from .utils import prod, get_const_int from .math import cast def cumsum(data, axis=None, dtype=None): - if dtype is None: + if dtype is None or dtype == "": dtype = data.dtype def maybe_cast(x): @@ -38,6 +38,9 @@ def maybe_cast(x): cumsum_axis_len = prod(data.shape) shape = (cumsum_axis_len,) else: + if not isinstance(axis, int): + axis = get_const_int(axis) + shape = data.shape cumsum_axis_len = shape[axis] diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ee9d867d6e85..d18f88f0092a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3674,8 +3674,54 @@ RELAY_REGISTER_OP("adv_index") .set_attr("TOpPattern", kInjective) .set_attr("FTVMCompute", AdvIndexCompute); - TVM_REGISTER_NODE_TYPE(CumsumAttrs); +bool CumsumRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + ICHECK(types[0].as()) + << "cumsum: expect input type to be TensorType but get " << types[0]; + return false; + } + + const auto* param = attrs.as(); + + auto dtype = param->dtype; + if (dtype.is_void()) { + dtype = data->dtype; + } + + if (param->axis.defined()) { + reporter->Assign(types[1], TensorType(data->shape, dtype)); + } else { + auto prod = data->shape[0]; + for (size_t i = 1; i < data->shape.size(); ++i) { + prod = prod * data->shape[i]; + } + reporter->Assign(types[1], TensorType({prod}, dtype)); + } + + return true; +} + +Expr MakeCumsum(Expr data, Integer axis, DataType dtype) { + auto attrs = make_object(); + attrs->dtype = dtype; + attrs->axis = axis; + static const Op& op = Op::Get("cumsum"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.cumsum").set_body_typed(MakeCumsum); + +RELAY_REGISTER_OP("cumsum") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Cumsum", CumsumRel) + .set_attr("TOpPattern", kOpaque); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 5e44170b6428..2a16a21fcefe 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1342,6 +1342,40 @@ def verify_adv_index(data_shape, index_shapes): verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)]) +def test_cumsum(): + def verify_cumsum(data_np, np_out, axis=None, out_dtype=None): + inp = relay.var("data", relay.TensorType(data_np.shape, str(data_np.dtype))) + + out = relay.op.cumsum(inp, axis, out_dtype) + func = relay.Function([inp], out) + + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(data_np) + tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=1e-5) + + data = np.array([2, 3, 0]) + verify_cumsum(data, np.cumsum(data)) + verify_cumsum(data, np.cumsum(data), out_dtype="int64") + + data = np.random.randn(10, 10) + verify_cumsum(data, np.cumsum(data)) + verify_cumsum(data, np.cumsum(data, axis=0), axis=0) + verify_cumsum(data, np.cumsum(data, axis=1), axis=1) + + data = np.random.randn(10, 5, 10).astype("float32") + verify_cumsum(data, np.cumsum(data)) + verify_cumsum(data, np.cumsum(data, axis=0), axis=0) + verify_cumsum(data, np.cumsum(data, axis=1), axis=1) + verify_cumsum(data, np.cumsum(data, axis=-1), axis=-1) + + data = np.random.rand(10) > 0.5 + data = data.astype(np.int32) + verify_cumsum(data, np.cumsum(data, dtype=np.int32)) + verify_cumsum(data, np.cumsum(data, dtype="int64"), out_dtype="int64") + + if __name__ == "__main__": test_cast() test_zeros_ones() @@ -1379,3 +1413,4 @@ def verify_adv_index(data_shape, index_shapes): test_sparse_to_dense() test_fixed_point_multiply() test_adv_index() + test_cumsum() diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index 30ca322012bb..81302e8f0dd5 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -33,30 +33,32 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): data = np.array([2, 3, 0]) check_cumsum(np.cumsum(data), data) - data = np.random.randn(10, 10) - check_cumsum(np.cumsum(data), data) - check_cumsum(np.cumsum(data, axis=0), data, axis=0) - check_cumsum(np.cumsum(data, axis=1), data, axis=1) - - data = np.random.randn(10, 5, 10) - check_cumsum(np.cumsum(data), data) - check_cumsum(np.cumsum(data, axis=0), data, axis=0) - check_cumsum(np.cumsum(data, axis=1), data, axis=1) - check_cumsum(np.cumsum(data, axis=-1), data, axis=-1) - data = np.random.rand(10) > 0.5 data = data.astype(np.int32) check_cumsum(np.cumsum(data, dtype=np.int32), data) check_cumsum(np.cumsum(data), data, dtype="int64") - data = np.random.randint(-100, 100, size=(100, 100)).astype(np.int32) - check_cumsum(np.cumsum(data, dtype=np.int32), data) - check_cumsum(np.cumsum(data), data, dtype="int64") - check_cumsum(np.cumsum(data, axis=0, dtype=np.int32), data, axis=0) - check_cumsum(np.cumsum(data, axis=1, dtype=np.int32), data, axis=1) + for in_dtype in ["float32", "float64"]: + data = np.random.randn(10, 10).astype(in_dtype) + check_cumsum(np.cumsum(data), data) + check_cumsum(np.cumsum(data, axis=0), data, axis=0) + check_cumsum(np.cumsum(data, axis=1), data, axis=1) - data = np.random.randint(1 << 30, (1 << 31) - 1, size=(100)).astype(np.int32) - check_cumsum(np.cumsum(data), data, dtype="int64") + data = np.random.randn(10, 5, 10).astype(in_dtype) + check_cumsum(np.cumsum(data), data) + check_cumsum(np.cumsum(data, axis=0), data, axis=0) + check_cumsum(np.cumsum(data, axis=1), data, axis=1) + check_cumsum(np.cumsum(data, axis=-1), data, axis=-1) + + for in_dtype in ["int32", "int64"]: + data = np.random.randint(-100, 100, size=(100, 100)).astype(in_dtype) + check_cumsum(np.cumsum(data, dtype=in_dtype), data) + check_cumsum(np.cumsum(data), data, dtype="int64") + check_cumsum(np.cumsum(data, axis=0, dtype=in_dtype), data, axis=0) + check_cumsum(np.cumsum(data, axis=1, dtype=in_dtype), data, axis=1) + + data = np.random.randint(1 << 30, (1 << 31) - 1, size=(100)).astype(in_dtype) + check_cumsum(np.cumsum(data), data, dtype="int64") if __name__ == "__main__": From e2d498b7712570802c917641e32adf7d113e9b03 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 19 Jan 2021 18:45:44 +0900 Subject: [PATCH 03/27] add torch frontend converter --- python/tvm/relay/frontend/pytorch.py | 11 +++++++++++ tests/python/frontend/pytorch/test_forward.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 991e3a8a0032..d991f31eba16 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2081,6 +2081,16 @@ def is_floating_point(self, inputs, input_types): is_float = input_type in ["float32", "float64", "float16", "bfloat16"] return _expr.const(is_float) + def cumsum(self, inputs, input_types): + data = inputs[0] + dim = inputs[1] + dtype = inputs[2] + + if inputs[2] is not None: + dtype = _convert_dtype_value(inputs[2]) + + return _op.cumsum(data, axis=dim, dtype=dtype) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2278,6 +2288,7 @@ def create_convert_map(self): "aten::__not__": self.logical_not, "aten::hardswish_": self.hard_swish, "aten::hardswish": self.hard_swish, + "aten::cumsum": self.cumsum, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 7cdd450448ca..8873196d7272 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3452,6 +3452,21 @@ def test_hard_swish(): verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input) +def test_cumsum(): + def test_fn(dim, dtype=None): + return lambda x: torch.cumsum(x, dim=dim, dtype=dtype) + + inp = torch.randint(0, 100, (10000,), dtype=torch.int32) + verify_model(test_fn(0), [inp]) + verify_model(test_fn(0), [inp.to(torch.int64)]) + verify_model(test_fn(0, dtype=torch.int64), [inp.to(torch.int64)]) + + inp = torch.randn((100, 100), dtype=torch.float32) + verify_model(test_fn(dim=0, dtype=torch.float64), [inp]) + verify_model(test_fn(dim=1), [inp]) +>>>>>>> d30410e10... add torch frontend converter + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -3580,6 +3595,7 @@ def test_hard_swish(): test_forward_scatter() test_numel() test_bincount() + test_cumsum() # Model tests test_resnet18() From 205a9a09e57f8c36c62058fcd3a6b3694395b8ac Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 19 Jan 2021 18:46:11 +0900 Subject: [PATCH 04/27] fix for importing detr --- python/tvm/relay/frontend/pytorch.py | 20 ++++++++++++++----- tests/python/frontend/pytorch/test_forward.py | 10 ++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index d991f31eba16..be12d0f46038 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -399,10 +399,7 @@ def slice(self, inputs, input_types): begin = [0] * ndim dim = int(inputs[1]) stride = int(inputs[4]) - if isinstance(inputs[2], _expr.Call): - begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) - else: - begin[dim] = int(inputs[2]) + begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) # Process begin if not isinstance(begin[dim], int): @@ -551,7 +548,13 @@ def reciprocal(self, inputs, input_types): def repeat(self, inputs, input_types): data = inputs[0] - reps = inputs[1] + reps = [] + for r in inputs[1]: + if isinstance(r, int): + reps.append(r) + else: + reps.append(int(_infer_value(r, {}).asnumpy())) + return _op.transform.tile(data, reps=reps) def repeat_interleave(self, inputs, input_types): @@ -2091,6 +2094,12 @@ def cumsum(self, inputs, input_types): return _op.cumsum(data, axis=dim, dtype=dtype) + def masked_fill(self, inputs, input_types): + mask = inputs[1] + value = _op.cast(_wrap_const(inputs[2]), input_types[0]) + + return _op.where(mask, value, inputs[0]) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2289,6 +2298,7 @@ def create_convert_map(self): "aten::hardswish_": self.hard_swish, "aten::hardswish": self.hard_swish, "aten::cumsum": self.cumsum, + "aten::masked_fill": self.masked_fill, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 8873196d7272..e92670cbbf4a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3467,6 +3467,15 @@ def test_fn(dim, dtype=None): >>>>>>> d30410e10... add torch frontend converter +def test_masked_fill(): + def test_fn(x, mask): + return torch.masked_fill(x, mask, 0.0) + + inp = torch.randn(100, 100) + verify_model(test_fn, [inp, inp > 0.5]) + verify_model(test_fn, [inp.to(torch.float64), inp > 0.5]) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -3596,6 +3605,7 @@ def test_fn(dim, dtype=None): test_numel() test_bincount() test_cumsum() + test_masked_fill() # Model tests test_resnet18() From 91a2c6639a4287fe8d918e2dd611507188c0f1ee Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 07:21:01 +0900 Subject: [PATCH 05/27] fix bad merge --- tests/python/frontend/pytorch/test_forward.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e92670cbbf4a..3627df3e029b 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3464,7 +3464,9 @@ def test_fn(dim, dtype=None): inp = torch.randn((100, 100), dtype=torch.float32) verify_model(test_fn(dim=0, dtype=torch.float64), [inp]) verify_model(test_fn(dim=1), [inp]) ->>>>>>> d30410e10... add torch frontend converter + + inp = torch.randn((100, 100), dtype=torch.float32) > 0.5 + verify_model(test_fn(dim=0, dtype=torch.int32), [inp]) def test_masked_fill(): From 49f7ce72cbb8e9c77ceb232d950d2c770037cc32 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 20:44:15 +0900 Subject: [PATCH 06/27] begin cuda cumsum --- python/tvm/topi/cuda/__init__.py | 1 + python/tvm/topi/cuda/scan.py | 10 ++++++++++ tests/python/topi/python/test_topi_cumsum.py | 4 +++- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 42bf980bec4c..e0ff5a12a9b2 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -56,3 +56,4 @@ from .correlation import * from .sparse import * from .argwhere import * +from .scan import * diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index f19e4a14239a..236f079e9f59 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -404,3 +404,13 @@ def traverse(op): for out in outs: traverse(out.op) return s + + +def cumsum(data, axis=None, dtype=None): + if axis is None and axis != 0: + axis = 0 + ex_scan = exclusive_scan(data, axis, output_dtype=dtype) + if dtype is not None and data.dtype != dtype: + data = cast(data, dtype) + in_scan = data + ex_scan + return in_scan diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index 81302e8f0dd5..b03b66893571 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -26,6 +26,7 @@ def test_cumsum(ctx, target): def check_cumsum(np_ref, data, axis=None, dtype=None): implementations = { "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern), + "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule) @@ -62,4 +63,5 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): if __name__ == "__main__": - test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm")) + # test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm")) + test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda")) From eca7420d14fa5a1612fd747f28b8a2a2c720b549 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 21:13:42 +0900 Subject: [PATCH 07/27] support non innermost axis --- python/tvm/topi/cuda/scan.py | 110 +++++++++++++++++++++-------------- python/tvm/topi/cuda/sort.py | 7 +-- python/tvm/topi/cumsum.py | 2 +- python/tvm/topi/utils.py | 5 ++ 4 files changed, 72 insertions(+), 52 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 236f079e9f59..dd459b679450 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -19,8 +19,8 @@ import tvm from tvm import te from tvm._ffi import get_global_func -from ..transform import expand_dims, squeeze -from ..utils import ceil_div +from ..transform import expand_dims, squeeze, transpose, reshape +from ..utils import ceil_div, swap, prod from ..math import cast from .. import tag from .injective import schedule_injective_from_existing @@ -319,62 +319,79 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): Returned if return_reduction is True. """ # TODO(masahi): Support other binary operators - ndim = len(data.shape) - if axis < 0: - axis += ndim - assert axis == ndim - 1, "Only support scan on the inner most axis." + def do_scan(data, output_dtype): + target = tvm.target.Target.current() + if target and target.kind.name == "cuda" and is_thrust_available(): + return scan_thrust(data, output_dtype, exclusive=True, return_reduction=return_reduction) + + if ndim == 1: + # TIR exclusive scan accepts only 2D inputs. + data = expand_dims(data, axis=0) + + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) + + if len(data.shape) == 2: + if return_reduction: + output, reduction = te.extern( + [data.shape, (data.shape[0],)], + [data], + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]), + dtype=[data.dtype, output_dtype], + in_buffers=[data_buf], + name="exclusive_scan", + tag="exclusive_scan_gpu", + ) + else: + output = te.extern( + [data.shape], + [data], + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]), + dtype=[output_dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="exclusive_scan", + tag="exclusive_scan_gpu", + ) + reduction = None + else: + assert False, "Unsupported dimension {}".format(ndim) + + if ndim == 1: + output = squeeze(output, 0) + if return_reduction: + reduction = squeeze(reduction, 0) + + if return_reduction: + return output, reduction + + return output if output_dtype is None: output_dtype = data.dtype - target = tvm.target.Target.current() - if target and target.kind.name == "cuda" and is_thrust_available(): - return scan_thrust(data, output_dtype, exclusive=True, return_reduction=return_reduction) - - if ndim == 1: - # TIR exclusive scan accepts only 2D inputs. - data = expand_dims(data, axis=0) + ndim = len(data.shape) + if axis < 0: + axis += ndim - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) + if axis != ndim - 1: + axes = swap(list(range(ndim)), axis) + data = transpose(data, axes) - if len(data.shape) == 2: - if return_reduction: - output, reduction = te.extern( - [data.shape, (data.shape[0],)], - [data], - lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]), - dtype=[data.dtype, output_dtype], - in_buffers=[data_buf], - name="exclusive_scan", - tag="exclusive_scan_gpu", - ) - else: - output = te.extern( - [data.shape], - [data], - lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]), - dtype=[output_dtype], - in_buffers=[data_buf], - out_buffers=[output_buf], - name="exclusive_scan", - tag="exclusive_scan_gpu", - ) - reduction = None + if return_reduction: + output, reduction = do_scan(data, output_dtype) else: - assert False, "Unsupported dimension {}".format(ndim) + output = do_scan(data, output_dtype) - if ndim == 1: - output = squeeze(output, 0) - if return_reduction: - reduction = squeeze(reduction, 0) + if axis != ndim - 1: + axes = swap(list(range(ndim)), axis) + output = transpose(output, axes) if return_reduction: return output, reduction return output - def schedule_scan(outs): """Schedule for scan operator. @@ -407,8 +424,11 @@ def traverse(op): def cumsum(data, axis=None, dtype=None): - if axis is None and axis != 0: + if axis is None: axis = 0 + cumsum_axis_len = prod(data.shape) + data = reshape(data, (cumsum_axis_len,)) + ex_scan = exclusive_scan(data, axis, output_dtype=dtype) if dtype is not None and data.dtype != dtype: data = cast(data, dtype) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 18340385205e..c0f076fb6065 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -23,12 +23,7 @@ from .injective import schedule_injective_from_existing from ..transform import strided_slice, transpose from .. import tag -from ..utils import ceil_div - - -def swap(arr, axis): - """ swap arr[axis] and arr[-1] """ - return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]] +from ..utils import ceil_div, swap def _schedule_sort(outs): diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py index 0c6aec4fbc97..e69c6cb8afe7 100644 --- a/python/tvm/topi/cumsum.py +++ b/python/tvm/topi/cumsum.py @@ -33,7 +33,7 @@ def maybe_cast(x): axis_mul_before = 1 axis_mul_after = 1 - if axis is None and axis != 0: + if axis is None: axis = 0 cumsum_axis_len = prod(data.shape) shape = (cumsum_axis_len,) diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index dfc226f0c331..cd9f0c61c854 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -492,3 +492,8 @@ def is_empty_shape(shape): def ceil_div(a, b): """Return ceil division of a by b""" return tvm.tir.indexdiv(a + (b - 1), b) + + +def swap(arr, axis): + """ swap arr[axis] and arr[-1] """ + return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]] From 0e6b2c5b2297b3a6d96362a918911ab25d01e798 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 21:38:08 +0900 Subject: [PATCH 08/27] support rank higher than 3 --- python/tvm/topi/cuda/scan.py | 57 ++++++++++++++-------------- src/runtime/contrib/thrust/thrust.cu | 23 +++++++++-- 2 files changed, 48 insertions(+), 32 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index dd459b679450..f20261081a4c 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -41,8 +41,8 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): 1D Buffer of size [batch_size], to store the sum of each row. """ - batch_size = data.shape[0] - scan_axis_size = data.shape[1] + batch_size = prod(data.shape[:-1]) + scan_axis_size = data.shape[-1] ib = tvm.tir.ir_builder.create() @@ -76,7 +76,7 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): ib.scope_attr(by, "thread_extent", nthread_by) tid = bx * nthread_tx + tx with ib.if_scope(tid < scan_axis_size): - output[by, tid] = data[by, tid] + output[by * scan_axis_size + tid] = cast(data[by * scan_axis_size + tid], out_dtype) nthread_tx = max_threads nthread_bx = ceil_div(scan_axis_size, max_threads) @@ -331,31 +331,28 @@ def do_scan(data, output_dtype): data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) - if len(data.shape) == 2: - if return_reduction: - output, reduction = te.extern( - [data.shape, (data.shape[0],)], - [data], - lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]), - dtype=[data.dtype, output_dtype], - in_buffers=[data_buf], - name="exclusive_scan", - tag="exclusive_scan_gpu", - ) - else: - output = te.extern( - [data.shape], - [data], - lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]), - dtype=[output_dtype], - in_buffers=[data_buf], - out_buffers=[output_buf], - name="exclusive_scan", - tag="exclusive_scan_gpu", - ) - reduction = None + if return_reduction: + output, reduction = te.extern( + [data.shape, (data.shape[0],)], + [data], + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]), + dtype=[data.dtype, output_dtype], + in_buffers=[data_buf], + name="exclusive_scan", + tag="exclusive_scan_gpu", + ) else: - assert False, "Unsupported dimension {}".format(ndim) + output = te.extern( + [data.shape], + [data], + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]), + dtype=[output_dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="exclusive_scan", + tag="exclusive_scan_gpu", + ) + reduction = None if ndim == 1: output = squeeze(output, 0) @@ -392,6 +389,7 @@ def do_scan(data, output_dtype): return output + def schedule_scan(outs): """Schedule for scan operator. @@ -426,11 +424,12 @@ def traverse(op): def cumsum(data, axis=None, dtype=None): if axis is None: axis = 0 - cumsum_axis_len = prod(data.shape) - data = reshape(data, (cumsum_axis_len,)) + data = reshape(data, (prod(data.shape),)) ex_scan = exclusive_scan(data, axis, output_dtype=dtype) + if dtype is not None and data.dtype != dtype: data = cast(data, dtype) + in_scan = data + ex_scan return in_scan diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 4e3e3a81af1a..1121fa85a7aa 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -275,7 +275,10 @@ void thrust_scan(DLTensor* data, if (scan_size == 0) return; - if (data->ndim == 1 || (data->ndim == 2 && data->shape[0] == 1)) { + int64_t size = 1; + for (int i = 0; i < data->ndim; ++i) size *= data->shape[i]; + + if (size == static_cast(data->shape[data->ndim - 1])) { if (exclusive) { thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr); } else { @@ -294,8 +297,6 @@ void thrust_scan(DLTensor* data, return i / scan_size; }; // NOLINT(*) auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key); - int64_t size = 1; - for (int i = 0; i < data->ndim; ++i) size *= data->shape[i]; if (exclusive) { thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr); @@ -320,18 +321,34 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") thrust_scan(data, output, exclusive); } else if (out_dtype == "int64") { thrust_scan(data, output, exclusive); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (in_dtype == "int64") { if (out_dtype == "int64") { thrust_scan(data, output, exclusive); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (in_dtype == "float32") { if (out_dtype == "float32") { thrust_scan(data, output, exclusive); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (in_dtype == "float64") { + if (out_dtype == "float4") { + thrust_scan(data, output, exclusive); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } From 2c33718fe07e452792e4c579f8bd5a74a3d8b315 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 22:03:43 +0900 Subject: [PATCH 09/27] making binop parameter --- python/tvm/topi/cuda/scan.py | 52 ++++++++++++++++------------ src/runtime/contrib/thrust/thrust.cu | 4 +-- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index f20261081a4c..3bcc05053b62 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -26,7 +26,10 @@ from .injective import schedule_injective_from_existing -def exclusive_sum_scan2d_ir(data, output, reduction=None): +binop_name_to_func = {"sum": tvm.tir.generic.add} + + +def exclusive_sum_scan2d_ir(data, output, reduction=None, binop="sum"): """Low level IR to do exclusive sum scan along rows of 2D input. Parameters @@ -157,7 +160,7 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): return ib.get() -def get_reduction_from_exclusive_scan(data, ex_scan_output): +def get_reduction_from_exclusive_scan(data, ex_scan_output, binop="sum"): """Return the sum of the last element of data and the exclusive scan output. The is the reduction of data along each row (for 2-D case). @@ -182,8 +185,8 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output): ex_scan_output = expand_dims(ex_scan_output, axis=0) def ir(data, data_ex_scan, reduction): - batch_size = data.shape[0] - num_anchors = data.shape[1] + batch_size = prod(data.shape[:-1]) + scan_axis_size = data.shape[-1] ib = tvm.tir.ir_builder.create() @@ -201,14 +204,13 @@ def ir(data, data_ex_scan, reduction): ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx with ib.if_scope(tid < batch_size): - with ib.if_scope(num_anchors > 0): - reduction[tid] = data_ex_scan[tid, num_anchors - 1] + data[tid, num_anchors - 1] + with ib.if_scope(scan_axis_size > 0): + reduction[tid] = data_ex_scan[tid * scan_axis_size + scan_axis_size - 1] + data[tid, scan_axis_size - 1] with ib.else_scope(): reduction[tid] = 0 return ib.get() - assert len(data.shape) == 2, "Only 2D input supported for now" data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "valid_indices_buf", data_alignment=8) ex_scan_output_buf = tvm.tir.decl_buffer( ex_scan_output.shape, ex_scan_output.dtype, "ex_scan_output_buf", data_alignment=8 @@ -235,7 +237,7 @@ def is_thrust_available(): return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None -def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False): +def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, binop="sum"): """Do exclusive scan on 1D input or along rows of 2D input, using thrust. Parameters @@ -288,7 +290,7 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False): return output -def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): +def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None, binop="sum"): """Do exclusive scan on 1D input or along rows of 2D input. Parameters @@ -308,6 +310,8 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): output_dtype: string, optional The dtype of the output scan tensor. If not provided, the dtype of the input is used. + biop: TODO + Returns ------- output : tvm.te.Tensor @@ -321,11 +325,11 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): # TODO(masahi): Support other binary operators def do_scan(data, output_dtype): target = tvm.target.Target.current() - if target and target.kind.name == "cuda" and is_thrust_available(): - return scan_thrust(data, output_dtype, exclusive=True, return_reduction=return_reduction) + # if target and target.kind.name == "cuda" and is_thrust_available(): + # return scan_thrust(data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop) if ndim == 1: - # TIR exclusive scan accepts only 2D inputs. + # TIR exclusive scan accepts only 2D or higher-rank inputs. data = expand_dims(data, axis=0) data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) @@ -335,7 +339,7 @@ def do_scan(data, output_dtype): output, reduction = te.extern( [data.shape, (data.shape[0],)], [data], - lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]), + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1], binop=binop), dtype=[data.dtype, output_dtype], in_buffers=[data_buf], name="exclusive_scan", @@ -345,7 +349,7 @@ def do_scan(data, output_dtype): output = te.extern( [data.shape], [data], - lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]), + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], binop=binop), dtype=[output_dtype], in_buffers=[data_buf], out_buffers=[output_buf], @@ -390,6 +394,16 @@ def do_scan(data, output_dtype): return output +def inclusive_scan(data, axis=None, output_dtype=None, binop="sum"): + """TODO""" + ex_scan = exclusive_scan(data, axis, output_dtype=output_dtype, binop=binop) + + if output_dtype is not None and data.dtype != output_dtype: + data = cast(data, output_dtype) + + return binop_name_to_func[binop](data, ex_scan) + + def schedule_scan(outs): """Schedule for scan operator. @@ -422,14 +436,8 @@ def traverse(op): def cumsum(data, axis=None, dtype=None): + """TODO""" if axis is None: axis = 0 data = reshape(data, (prod(data.shape),)) - - ex_scan = exclusive_scan(data, axis, output_dtype=dtype) - - if dtype is not None and data.dtype != dtype: - data = cast(data, dtype) - - in_scan = data + ex_scan - return in_scan + return inclusive_scan(data, axis, output_dtype=dtype, binop="sum") diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 1121fa85a7aa..15db3e66b64d 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -275,7 +275,7 @@ void thrust_scan(DLTensor* data, if (scan_size == 0) return; - int64_t size = 1; + size_t size = 1; for (int i = 0; i < data->ndim; ++i) size *= data->shape[i]; if (size == static_cast(data->shape[data->ndim - 1])) { @@ -347,7 +347,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (in_dtype == "float64") { - if (out_dtype == "float4") { + if (out_dtype == "float64") { thrust_scan(data, output, exclusive); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; From 7cf958ee078fff57073f8cc0420ba3769391c924 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 22:56:17 +0900 Subject: [PATCH 10/27] fix overflow issue in thrust scan --- src/runtime/contrib/thrust/thrust.cu | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 15db3e66b64d..345fd1a73edb 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -278,9 +278,19 @@ void thrust_scan(DLTensor* data, size_t size = 1; for (int i = 0; i < data->ndim; ++i) size *= data->shape[i]; + const bool need_cast = std::is_same::value == false; + + auto data_cast_ptr = thrust::make_transform_iterator(data_ptr, [] __host__ __device__(InType v) { + return static_cast(v); + }); // NOLINT(*) + if (size == static_cast(data->shape[data->ndim - 1])) { - if (exclusive) { + if (exclusive && need_cast) { + thrust::exclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr); + } else if (exclusive && !need_cast) { thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr); + } else if (!exclusive && need_cast) { + thrust::inclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr); } else { thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr); } @@ -291,15 +301,19 @@ void thrust_scan(DLTensor* data, // This is for constructing a sequence 0, 0, 0,...,1, 1, 1,...,2, 2, 2,..., // without materializing the sequence vector - auto counting_iter = thrust::counting_iterator(0); + auto counting_iter = thrust::counting_iterator(0); // Without __host__ annotation, cub crashes - auto linear_index_to_scan_key = [scan_size] __host__ __device__(int64_t i) { + auto linear_index_to_scan_key = [scan_size] __host__ __device__(size_t i) { return i / scan_size; }; // NOLINT(*) auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key); - if (exclusive) { + if (exclusive && need_cast) { + thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr); + } else if (exclusive && !need_cast) { thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr); + } else if (!exclusive && need_cast) { + thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr); } else { thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr); } From 0b056a16fa3fc179b9afb0321521caa2849a4be2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 22:58:24 +0900 Subject: [PATCH 11/27] generic binop parameter working --- python/tvm/topi/cuda/scan.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 3bcc05053b62..33e678a714fd 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -29,7 +29,7 @@ binop_name_to_func = {"sum": tvm.tir.generic.add} -def exclusive_sum_scan2d_ir(data, output, reduction=None, binop="sum"): +def exclusive_scan_ir(data, output, reduction=None, binop="sum"): """Low level IR to do exclusive sum scan along rows of 2D input. Parameters @@ -90,6 +90,7 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None, binop="sum"): lim = tvm.tir.generic.cast( tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64" ) + op = binop_name_to_func[binop] with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width @@ -114,9 +115,10 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None, binop="sum"): middle[0] = start[0] + tvm.tir.indexdiv(width, 2) end[0] = tvm.te.min(start[0] + width, scan_axis_size) with ib.if_scope(middle[0] < scan_axis_size): - output[by * scan_axis_size + end[0] - 1] += output[ - by * scan_axis_size + middle[0] - 1 - ] + output[by * scan_axis_size + end[0] - 1] = op( + output[by * scan_axis_size + end[0] - 1], + output[by * scan_axis_size + middle[0] - 1], + ) # Down Sweep of exclusive scan with ib.new_scope(): @@ -156,7 +158,9 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None, binop="sum"): output[by * scan_axis_size + middle[0] - 1] = output[ by * scan_axis_size + end[0] - 1 ] - output[by * scan_axis_size + end[0] - 1] += tmp[0] + output[by * scan_axis_size + end[0] - 1] = op( + output[by * scan_axis_size + end[0] - 1], tmp[0] + ) return ib.get() @@ -205,7 +209,10 @@ def ir(data, data_ex_scan, reduction): tid = bx * max_threads + tx with ib.if_scope(tid < batch_size): with ib.if_scope(scan_axis_size > 0): - reduction[tid] = data_ex_scan[tid * scan_axis_size + scan_axis_size - 1] + data[tid, scan_axis_size - 1] + reduction[tid] = binop_name_to_func[binop]( + data_ex_scan[tid * scan_axis_size + scan_axis_size - 1], + data[tid, scan_axis_size - 1], + ) with ib.else_scope(): reduction[tid] = 0 @@ -269,11 +276,12 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, bino """ data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) + binop_to_thrust_func_name = {"sum": "tvm.contrib.thrust.sum_scan"} output = te.extern( [data.shape], [data], lambda ins, outs: tvm.tir.call_packed( - "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive + binop_to_thrust_func_name[binop], ins[0], outs[0], exclusive ), dtype=[output_dtype], in_buffers=[data_buf], @@ -284,7 +292,7 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, bino if return_reduction: assert exclusive, "return_reduction should be False for inclusive scan" - reduction = get_reduction_from_exclusive_scan(data, output) + reduction = get_reduction_from_exclusive_scan(data, output, binop) return output, reduction return output @@ -325,8 +333,10 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None, bin # TODO(masahi): Support other binary operators def do_scan(data, output_dtype): target = tvm.target.Target.current() - # if target and target.kind.name == "cuda" and is_thrust_available(): - # return scan_thrust(data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop) + if target and target.kind.name == "cuda" and is_thrust_available(): + return scan_thrust( + data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop + ) if ndim == 1: # TIR exclusive scan accepts only 2D or higher-rank inputs. @@ -339,7 +349,7 @@ def do_scan(data, output_dtype): output, reduction = te.extern( [data.shape, (data.shape[0],)], [data], - lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1], binop=binop), + lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], outs[1], binop=binop), dtype=[data.dtype, output_dtype], in_buffers=[data_buf], name="exclusive_scan", @@ -349,7 +359,7 @@ def do_scan(data, output_dtype): output = te.extern( [data.shape], [data], - lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], binop=binop), + lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], binop=binop), dtype=[output_dtype], in_buffers=[data_buf], out_buffers=[output_buf], From a8324fa9b7c2e8ec0f6c36c0aa19d21e03d789f6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 23:36:20 +0900 Subject: [PATCH 12/27] relay test working --- python/tvm/relay/op/strategy/cuda.py | 12 ++++++++++++ python/tvm/topi/cuda/scan.py | 15 +++++++++------ tests/python/relay/test_op_level3.py | 14 ++++++++------ tests/python/topi/python/test_topi_cumsum.py | 2 +- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 3863df0fd831..346e93445f1c 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -996,3 +996,15 @@ def argwhere_strategy_cuda(attrs, inputs, out_type, target): name="argwhere.cuda", ) return strategy + + +@cumsum_strategy.register(["cuda", "gpu"]) +def cumsum_strategy_cuda(attrs, inputs, out_type, target): + """cumsum cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_cumsum(topi.cuda.cumsum), + wrap_topi_schedule(topi.cuda.schedule_scan), + name="cumsum.cuda", + ) + return strategy diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 33e678a714fd..2836c1d351a9 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -20,7 +20,7 @@ from tvm import te from tvm._ffi import get_global_func from ..transform import expand_dims, squeeze, transpose, reshape -from ..utils import ceil_div, swap, prod +from ..utils import ceil_div, swap, prod, get_const_int from ..math import cast from .. import tag from .injective import schedule_injective_from_existing @@ -42,6 +42,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop="sum"): reduction: Buffer, optional 1D Buffer of size [batch_size], to store the sum of each row. + + binop : TODO """ batch_size = prod(data.shape[:-1]) @@ -286,8 +288,8 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, bino dtype=[output_dtype], in_buffers=[data_buf], out_buffers=[output_buf], - name="exclusive_sum_scan2d", - tag="exclusive_sum_scan2d_gpu", + name="exclusive_scan_thrust", + tag="exclusive_scan_thrust_gpu", ) if return_reduction: @@ -378,7 +380,7 @@ def do_scan(data, output_dtype): return output - if output_dtype is None: + if output_dtype is None or output_dtype == "": output_dtype = data.dtype ndim = len(data.shape) @@ -404,11 +406,11 @@ def do_scan(data, output_dtype): return output -def inclusive_scan(data, axis=None, output_dtype=None, binop="sum"): +def inclusive_scan(data, axis=-1, output_dtype=None, binop="sum"): """TODO""" ex_scan = exclusive_scan(data, axis, output_dtype=output_dtype, binop=binop) - if output_dtype is not None and data.dtype != output_dtype: + if output_dtype is not None and data.dtype != output_dtype and output_dtype != "": data = cast(data, output_dtype) return binop_name_to_func[binop](data, ex_scan) @@ -450,4 +452,5 @@ def cumsum(data, axis=None, dtype=None): if axis is None: axis = 0 data = reshape(data, (prod(data.shape),)) + axis = get_const_int(axis) return inclusive_scan(data, axis, output_dtype=dtype, binop="sum") diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 2a16a21fcefe..acc8115a55a2 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1311,6 +1311,7 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ # verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) +@tvm.testing.uses_gpu def test_adv_index(): def verify_adv_index(data_shape, index_shapes): dtype = "float32" @@ -1342,8 +1343,9 @@ def verify_adv_index(data_shape, index_shapes): verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)]) +@tvm.testing.uses_gpu def test_cumsum(): - def verify_cumsum(data_np, np_out, axis=None, out_dtype=None): + def verify_cumsum(data_np, np_out, axis=None, out_dtype=None, rtol=1e-5, atol=1e-5): inp = relay.var("data", relay.TensorType(data_np.shape, str(data_np.dtype))) out = relay.op.cumsum(inp, axis, out_dtype) @@ -1353,7 +1355,7 @@ def verify_cumsum(data_np, np_out, axis=None, out_dtype=None): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(data_np) - tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=1e-5) + tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=rtol, atol=atol) data = np.array([2, 3, 0]) verify_cumsum(data, np.cumsum(data)) @@ -1365,10 +1367,10 @@ def verify_cumsum(data_np, np_out, axis=None, out_dtype=None): verify_cumsum(data, np.cumsum(data, axis=1), axis=1) data = np.random.randn(10, 5, 10).astype("float32") - verify_cumsum(data, np.cumsum(data)) - verify_cumsum(data, np.cumsum(data, axis=0), axis=0) - verify_cumsum(data, np.cumsum(data, axis=1), axis=1) - verify_cumsum(data, np.cumsum(data, axis=-1), axis=-1) + verify_cumsum(data, np.cumsum(data), rtol=1e-4, atol=1e-4) + verify_cumsum(data, np.cumsum(data, axis=0), axis=0, rtol=1e-4, atol=1e-4) + verify_cumsum(data, np.cumsum(data, axis=1), axis=1, rtol=1e-4, atol=1e-4) + verify_cumsum(data, np.cumsum(data, axis=-1), axis=-1, rtol=1e-4, atol=1e-4) data = np.random.rand(10) > 0.5 data = data.astype(np.int32) diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index b03b66893571..eed74bc6916e 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -63,5 +63,5 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): if __name__ == "__main__": - # test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm")) + test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm")) test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda")) From bb6340e826dce1b0060bb941326ef5baaaea6896 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 23:52:35 +0900 Subject: [PATCH 13/27] fixed for bool input --- src/runtime/contrib/thrust/thrust.cu | 14 +++++++++++++- tests/python/topi/python/test_topi_cumsum.py | 3 +++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 345fd1a73edb..a602dced9391 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -330,7 +330,19 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") auto in_dtype = DLDataType2String(data->dtype); auto out_dtype = DLDataType2String(output->dtype); - if (in_dtype == "int32") { + if (in_dtype == "bool") { + if (out_dtype == "int32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "int64") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (in_dtype == "int32") { if (out_dtype == "int32") { thrust_scan(data, output, exclusive); } else if (out_dtype == "int64") { diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index eed74bc6916e..a205199ef902 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -39,6 +39,9 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): check_cumsum(np.cumsum(data, dtype=np.int32), data) check_cumsum(np.cumsum(data), data, dtype="int64") + data = np.random.rand(10) > 0.5 + check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") + for in_dtype in ["float32", "float64"]: data = np.random.randn(10, 10).astype(in_dtype) check_cumsum(np.cumsum(data), data) From ae434e5c04aa0548e0836d8ff6e28dc6c66b129f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 25 Jan 2021 11:30:40 +0900 Subject: [PATCH 14/27] remove pytorch change --- python/tvm/relay/frontend/pytorch.py | 31 +++---------------- tests/python/frontend/pytorch/test_forward.py | 28 ----------------- 2 files changed, 5 insertions(+), 54 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index be12d0f46038..991e3a8a0032 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -399,7 +399,10 @@ def slice(self, inputs, input_types): begin = [0] * ndim dim = int(inputs[1]) stride = int(inputs[4]) - begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) + if isinstance(inputs[2], _expr.Call): + begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) + else: + begin[dim] = int(inputs[2]) # Process begin if not isinstance(begin[dim], int): @@ -548,13 +551,7 @@ def reciprocal(self, inputs, input_types): def repeat(self, inputs, input_types): data = inputs[0] - reps = [] - for r in inputs[1]: - if isinstance(r, int): - reps.append(r) - else: - reps.append(int(_infer_value(r, {}).asnumpy())) - + reps = inputs[1] return _op.transform.tile(data, reps=reps) def repeat_interleave(self, inputs, input_types): @@ -2084,22 +2081,6 @@ def is_floating_point(self, inputs, input_types): is_float = input_type in ["float32", "float64", "float16", "bfloat16"] return _expr.const(is_float) - def cumsum(self, inputs, input_types): - data = inputs[0] - dim = inputs[1] - dtype = inputs[2] - - if inputs[2] is not None: - dtype = _convert_dtype_value(inputs[2]) - - return _op.cumsum(data, axis=dim, dtype=dtype) - - def masked_fill(self, inputs, input_types): - mask = inputs[1] - value = _op.cast(_wrap_const(inputs[2]), input_types[0]) - - return _op.where(mask, value, inputs[0]) - # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2297,8 +2278,6 @@ def create_convert_map(self): "aten::__not__": self.logical_not, "aten::hardswish_": self.hard_swish, "aten::hardswish": self.hard_swish, - "aten::cumsum": self.cumsum, - "aten::masked_fill": self.masked_fill, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3627df3e029b..7cdd450448ca 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3452,32 +3452,6 @@ def test_hard_swish(): verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input) -def test_cumsum(): - def test_fn(dim, dtype=None): - return lambda x: torch.cumsum(x, dim=dim, dtype=dtype) - - inp = torch.randint(0, 100, (10000,), dtype=torch.int32) - verify_model(test_fn(0), [inp]) - verify_model(test_fn(0), [inp.to(torch.int64)]) - verify_model(test_fn(0, dtype=torch.int64), [inp.to(torch.int64)]) - - inp = torch.randn((100, 100), dtype=torch.float32) - verify_model(test_fn(dim=0, dtype=torch.float64), [inp]) - verify_model(test_fn(dim=1), [inp]) - - inp = torch.randn((100, 100), dtype=torch.float32) > 0.5 - verify_model(test_fn(dim=0, dtype=torch.int32), [inp]) - - -def test_masked_fill(): - def test_fn(x, mask): - return torch.masked_fill(x, mask, 0.0) - - inp = torch.randn(100, 100) - verify_model(test_fn, [inp, inp > 0.5]) - verify_model(test_fn, [inp.to(torch.float64), inp > 0.5]) - - if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -3606,8 +3580,6 @@ def test_fn(x, mask): test_forward_scatter() test_numel() test_bincount() - test_cumsum() - test_masked_fill() # Model tests test_resnet18() From 3afe514748451e8c6e09bceedc66dd22a70e8ebb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 25 Jan 2021 11:36:38 +0900 Subject: [PATCH 15/27] fix pylint --- python/tvm/topi/cumsum.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py index e69c6cb8afe7..7df2fcff9aa9 100644 --- a/python/tvm/topi/cumsum.py +++ b/python/tvm/topi/cumsum.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """Cumsum operator""" from ..tir import decl_buffer, ir_builder from ..te import extern @@ -22,6 +23,7 @@ def cumsum(data, axis=None, dtype=None): + """TODO""" if dtype is None or dtype == "": dtype = data.dtype From 55065dfd798eaef66e765db1cd9be331d89bc162 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 25 Jan 2021 15:09:10 +0900 Subject: [PATCH 16/27] doc update --- include/tvm/relay/attrs/transform.h | 3 +- python/tvm/relay/op/transform.py | 22 +++- python/tvm/topi/cuda/scan.py | 107 +++++++++++++------ python/tvm/topi/cumsum.py | 22 +++- src/relay/op/tensor/transform.cc | 2 + tests/python/contrib/test_thrust.py | 4 +- tests/python/topi/python/test_topi_cumsum.py | 2 + 7 files changed, 125 insertions(+), 37 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 0694816feba2..43166249638a 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -438,12 +438,13 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode { } }; // struct MatrixSetDiagAttrs +/*! \brief Attributes used in cumsum operator */ struct CumsumAttrs : public tvm::AttrsNode { Integer axis; DataType dtype; TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") { TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue()); - TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue()); + TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue()); } }; diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index b75e0106f84a..9dcafe5a58cb 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1323,7 +1323,25 @@ def adv_index(inputs): def cumsum(data, axis=None, dtype=None): - """ - TODO + """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + axis : int, optional + Axis along which the cumulative sum is computed. The default (None) is to compute + the cumsum over the flattened array. + + dtype : string, optional + Type of the returned array and of the accumulator in which the elements are summed. + If dtype is not specified, it defaults to the dtype of data. + + Returns + ------- + result : relay.Expr + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. """ return _make.cumsum(data, axis, dtype) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 2836c1d351a9..0505460dfb47 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -35,15 +35,16 @@ def exclusive_scan_ir(data, output, reduction=None, binop="sum"): Parameters ---------- data : Buffer - Input data. 2-D Buffer with shape [batch_size, scan_axis_size]. + Input N-D Buffer. Scan is done over the innermost axis. output: Buffer - A buffer to store the output scan, of the same size as data + A buffer to store the output scan, of the same shape as data reduction: Buffer, optional - 1D Buffer of size [batch_size], to store the sum of each row. + (N-1)-D Buffer, to store the sum of each scan axis. - binop : TODO + biop: string, optional + A string specifying which binary operator to use. Currently only "sum" is supported. """ batch_size = prod(data.shape[:-1]) @@ -173,17 +174,18 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output, binop="sum"): Parameters ---------- data : tvm.te.Tensor - Input data. 1-D tensor with shape [scan_axis_size], or - 2-D tensor with shape [batch_size, scan_axis_size]. + Input data of any shape ex_scan_output : tvm.te.Tensor - 1-D tensor that is the exclusive scan of the input, or - 2-D tensor storing the exclusive scan of each row. + The output of exclusive scan on data + + biop: string, optional + A string specifying which binary operator to use. Currently only "sum" is supported. Returns ------- reduction : tvm.te.Tensor - 1-D tensor storing the reduction of each row. + (N-1)-D tensor storing the reduction of each scan axis. """ ndim = len(data.shape) if ndim == 1: @@ -226,7 +228,7 @@ def ir(data, data_ex_scan, reduction): ) reduction = te.extern( - [(data.shape[0],)], + [data.shape[:-1]], [data, ex_scan_output], lambda ins, outs: ir(ins[0], ins[1], outs[0]), dtype=[ex_scan_output.dtype], @@ -247,13 +249,12 @@ def is_thrust_available(): def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, binop="sum"): - """Do exclusive scan on 1D input or along rows of 2D input, using thrust. + """Do exclusive or inclusive scan on 1D or multidimensional input, using thrust. Parameters ---------- data : tvm.te.Tensor - Input data. 1-D tensor with shape [scan_axis_size], or - 2-D tensor with shape [batch_size, scan_axis_size]. + Input data of any shape. The scan is done over the innermost axis. output_dtype: string The dtype of the output scan tensor. @@ -262,18 +263,20 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, bino Whether or not do exclusive or inclusive scan. return_reduction: bool, optional - Whether or not return a 1-D tensor storing the reduction of each row. + Whether or not return a (N-1)-D tensor storing the reduction of each scan axis. Reductions are computed as part of the upsweep pass, so there is no extra cost. - If False, reductions are ignored. + If False, reductions are ignored. It must be False when exclusive is False. + + biop: string, optional + A string specifying which binary operator to use. Currently only "sum" is supported. Returns ------- output : tvm.te.Tensor - 1-D tensor that is the exclusive scan of the input, or - 2-D tensor storing the exclusive scan of each row. + A N-D tensor of the same rank N and shape as the input data. reduction : tvm.te.Tensor, optional - 1-D tensor storing the reduction of each row. + (N-1)-D tensor storing the reduction of each scan axis. Returned if return_reduction is True. """ data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) @@ -301,38 +304,38 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, bino def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None, binop="sum"): - """Do exclusive scan on 1D input or along rows of 2D input. + """Do exclusive scan on 1D or multidimensional input. Parameters ---------- data : tvm.te.Tensor - Input data. 1-D tensor with shape [scan_axis_size], or - 2-D tensor with shape [batch_size, scan_axis_size]. + Input data of any shape. axis: int, optional - The axis to do scan on. For now, only the inner most axis is supported. + The axis to do scan on. By default, scan is done on the innermost axis. return_reduction: bool, optional - Whether or not return a 1-D tensor storing the reduction of each row. + Whether or not return a tensor storing the reduction over each scan axis. + If the input rank is N, this tensor is of rank N - 1. Reductions are computed as part of the upsweep pass, so there is no extra cost. If False, reductions are ignored. output_dtype: string, optional The dtype of the output scan tensor. If not provided, the dtype of the input is used. - biop: TODO + biop: string, optional + A string specifying which binary operator to use. Currently only "sum" is supported. Returns ------- output : tvm.te.Tensor - 1-D tensor that is the exclusive scan of the input, or - 2-D tensor storing the exclusive scan of each row. + A N-D tensor of the same rank N and shape as the input data. reduction : tvm.te.Tensor, optional - 1-D tensor storing the reduction of each row. + (N-1)-D tensor storing the reduction of each scan axis. Returned if return_reduction is True. """ - # TODO(masahi): Support other binary operators + def do_scan(data, output_dtype): target = tvm.target.Target.current() if target and target.kind.name == "cuda" and is_thrust_available(): @@ -349,7 +352,7 @@ def do_scan(data, output_dtype): if return_reduction: output, reduction = te.extern( - [data.shape, (data.shape[0],)], + [data.shape, data.shape[:-1]], [data], lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], outs[1], binop=binop), dtype=[data.dtype, output_dtype], @@ -387,6 +390,8 @@ def do_scan(data, output_dtype): if axis < 0: axis += ndim + # If scan axis is not the innermost one, swap the scan and the innermost axes + # Scan is always done on the innermost axis, for performance reason. if axis != ndim - 1: axes = swap(list(range(ndim)), axis) data = transpose(data, axes) @@ -407,7 +412,27 @@ def do_scan(data, output_dtype): def inclusive_scan(data, axis=-1, output_dtype=None, binop="sum"): - """TODO""" + """Do inclusive scan on 1D or multidimensional input. + + Parameters + ---------- + data : tvm.te.Tensor + Input data of any shape. + + axis: int, optional + The axis to do scan on. By default, scan is done on the innermost axis. + + output_dtype: string, optional + The dtype of the output scan tensor. If not provided, the dtype of the input is used. + + biop: string, optional + A string specifying which binary operator to use. Currently only "sum" is supported. + + Returns + ------- + output : tvm.te.Tensor + A N-D tensor of the same rank N as the input data. + """ ex_scan = exclusive_scan(data, axis, output_dtype=output_dtype, binop=binop) if output_dtype is not None and data.dtype != output_dtype and output_dtype != "": @@ -448,7 +473,27 @@ def traverse(op): def cumsum(data, axis=None, dtype=None): - """TODO""" + """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis. + + Parameters + ---------- + data : tvm.te.Tensor + The input data to the operator. + + axis : int, optional + Axis along which the cumulative sum is computed. The default (None) is to compute + the cumsum over the flattened array. + + dtype : string, optional + Type of the returned array and of the accumulator in which the elements are summed. + If dtype is not specified, it defaults to the dtype of data. + + Returns + ------- + result : tvm.te.Tensor + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. + """ if axis is None: axis = 0 data = reshape(data, (prod(data.shape),)) diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py index 7df2fcff9aa9..8900c906cbd6 100644 --- a/python/tvm/topi/cumsum.py +++ b/python/tvm/topi/cumsum.py @@ -23,7 +23,27 @@ def cumsum(data, axis=None, dtype=None): - """TODO""" + """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis. + + Parameters + ---------- + data : tvm.te.Tensor + The input data to the operator. + + axis : int, optional + Axis along which the cumulative sum is computed. The default (None) is to compute + the cumsum over the flattened array. + + dtype : string, optional + Type of the returned array and of the accumulator in which the elements are summed. + If dtype is not specified, it defaults to the dtype of data. + + Returns + ------- + result : tvm.te.Tensor + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. + """ if dtype is None or dtype == "": dtype = data.dtype diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index d18f88f0092a..ff914f4c4776 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3717,6 +3717,8 @@ Expr MakeCumsum(Expr data, Integer axis, DataType dtype) { TVM_REGISTER_GLOBAL("relay.op._make.cumsum").set_body_typed(MakeCumsum); RELAY_REGISTER_OP("cumsum") + .describe( + R"doc(Return the cumulative sum of the elements along a given axis.)doc" TVM_ADD_FILELINE) .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(3) diff --git a/tests/python/contrib/test_thrust.py b/tests/python/contrib/test_thrust.py index 5f66d465bf17..c5b6a29d57d5 100644 --- a/tests/python/contrib/test_thrust.py +++ b/tests/python/contrib/test_thrust.py @@ -59,7 +59,7 @@ def test_exclusive_scan(): print("skip because thrust is not enabled...") return - for ishape in [(1,), (10, 10)]: + for ishape in [(10,), (10, 10), (10, 10, 10)]: values = te.placeholder(ishape, name="values", dtype="int32") with tvm.target.Target("cuda"): @@ -75,7 +75,7 @@ def test_exclusive_scan(): if len(ishape) == 1: reduction_shape = () else: - reduction_shape = (ishape[0],) + reduction_shape = ishape[:-1] reduction_np_out = np.zeros(reduction_shape, np.int32) diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index a205199ef902..a01a496f92e9 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -27,6 +27,7 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): implementations = { "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern), "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), + "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule) @@ -68,3 +69,4 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): if __name__ == "__main__": test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm")) test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda")) + test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx")) From 050d6aa7e350d2de7b6241cbe621d00e9a4a0e91 Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 26 Jan 2021 05:29:38 +0900 Subject: [PATCH 17/27] Update python/tvm/topi/cumsum.py Co-authored-by: Tristan Konolige --- python/tvm/topi/cumsum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py index 8900c906cbd6..4e3bd84adaa1 100644 --- a/python/tvm/topi/cumsum.py +++ b/python/tvm/topi/cumsum.py @@ -80,7 +80,7 @@ def gen_ir(data_buf, out_buf): data_buf = ib.buffer_ptr(data_buf) out_buf = ib.buffer_ptr(out_buf) - with ib.for_range(0, axis_mul_before) as i: + with ib.for_range(0, axis_mul_before, "i") as i: with ib.for_range(0, axis_mul_after) as j: base_idx = i * cumsum_axis_len * axis_mul_after + j out_buf[base_idx] = maybe_cast(data_buf[base_idx]) From a6b71d2f1139fbe2fc7bf050f79251270429b993 Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 26 Jan 2021 05:29:59 +0900 Subject: [PATCH 18/27] Update tests/python/relay/test_op_level3.py Co-authored-by: Tristan Konolige --- tests/python/relay/test_op_level3.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index acc8115a55a2..559eb2462fa8 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1343,19 +1343,18 @@ def verify_adv_index(data_shape, index_shapes): verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)]) -@tvm.testing.uses_gpu -def test_cumsum(): +@tvm.testing.parametrize_targets +def test_cumsum(target, ctx): def verify_cumsum(data_np, np_out, axis=None, out_dtype=None, rtol=1e-5, atol=1e-5): inp = relay.var("data", relay.TensorType(data_np.shape, str(data_np.dtype))) out = relay.op.cumsum(inp, axis, out_dtype) func = relay.Function([inp], out) - for target, ctx in tvm.testing.enabled_targets(): - for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(data_np) - tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=rtol, atol=atol) + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(data_np) + tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=rtol, atol=atol) data = np.array([2, 3, 0]) verify_cumsum(data, np.cumsum(data)) From 3ca20d85d300f829842c022567207a380fcc4998 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 26 Jan 2021 05:40:17 +0900 Subject: [PATCH 19/27] add example outputs --- python/tvm/relay/op/transform.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 9dcafe5a58cb..7e81be6c98b1 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1323,7 +1323,8 @@ def adv_index(inputs): def cumsum(data, axis=None, dtype=None): - """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis. + """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along + a given axis. Parameters ---------- @@ -1343,5 +1344,26 @@ def cumsum(data, axis=None, dtype=None): result : relay.Expr The result has the same size as data, and the same shape as data if axis is not None. If axis is None, the result is a 1-d array. + + Examples: + a = [[1,2,3], [4,5,6]] + + cumsum(a) # if axis is not provided, cumsum is done over the flattened input. + -> [ 1, 3, 6, 10, 15, 21] + + cumsum(a, dtype="float32") + -> [ 1., 3., 6., 10., 15., 21.] + + cumsum(a, axis=0) # sum over rows for each of the 3 columns + -> [[1, 2, 3], + [5, 7, 9]] + + cumsum(a, axis=1) + -> [[ 1, 3, 6], + [ 4, 9, 15]] + + a = [1, 0, 1, 0, 1, 1, 0] # a is a boolean array + cumsum(a, dtype=int32) # dtype should be provided to get the expected results + -> [1, 1, 2, 2, 3, 4, 4] """ return _make.cumsum(data, axis, dtype) From 0dabe12c1889c7b1448769924cbefbd578c00e46 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 26 Jan 2021 05:51:58 +0900 Subject: [PATCH 20/27] add supported input and output dtype in thrust log --- src/relay/op/tensor/transform.cc | 3 ++- src/runtime/contrib/thrust/thrust.cu | 18 ++++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ff914f4c4776..0e868cdc50c9 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3678,7 +3678,8 @@ TVM_REGISTER_NODE_TYPE(CumsumAttrs); bool CumsumRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 2); + // types: [data, output] + ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output"; const auto* data = types[0].as(); if (data == nullptr) { ICHECK(types[0].as()) diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index a602dced9391..931649d16aa7 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -340,7 +340,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") } else if (out_dtype == "float64") { thrust_scan(data, output, exclusive); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int32, int64, float32, and float64"; } } else if (in_dtype == "int32") { if (out_dtype == "int32") { @@ -352,7 +353,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") } else if (out_dtype == "float64") { thrust_scan(data, output, exclusive); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int32, int64, float32, and float64"; } } else if (in_dtype == "int64") { if (out_dtype == "int64") { @@ -362,7 +364,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") } else if (out_dtype == "float64") { thrust_scan(data, output, exclusive); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int64, float32, and float64"; } } else if (in_dtype == "float32") { if (out_dtype == "float32") { @@ -370,16 +373,19 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") } else if (out_dtype == "float64") { thrust_scan(data, output, exclusive); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are float32, and float64"; } } else if (in_dtype == "float64") { if (out_dtype == "float64") { thrust_scan(data, output, exclusive); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtype is float64"; } } else { - LOG(FATAL) << "Unsupported input dtype: " << in_dtype; + LOG(FATAL) << "Unsupported input dtype: " << in_dtype + << ". Supported input dtypes are bool, int32, int64, float32, and float64"; } }); From 8ec333e9a0f1b22570f4d792586dd24d23c7b5a6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 26 Jan 2021 05:53:37 +0900 Subject: [PATCH 21/27] adding more loop var names --- python/tvm/topi/cumsum.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py index 4e3bd84adaa1..b6939a174274 100644 --- a/python/tvm/topi/cumsum.py +++ b/python/tvm/topi/cumsum.py @@ -81,10 +81,10 @@ def gen_ir(data_buf, out_buf): out_buf = ib.buffer_ptr(out_buf) with ib.for_range(0, axis_mul_before, "i") as i: - with ib.for_range(0, axis_mul_after) as j: + with ib.for_range(0, axis_mul_after, "j") as j: base_idx = i * cumsum_axis_len * axis_mul_after + j out_buf[base_idx] = maybe_cast(data_buf[base_idx]) - with ib.for_range(0, cumsum_axis_len - 1) as _k: + with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k: k = _k + 1 cur_idx = base_idx + k * axis_mul_after prev_idx = base_idx + (k - 1) * axis_mul_after From 6f481fcad82ade411794723959aba3e8672b5bee Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 26 Jan 2021 06:21:42 +0900 Subject: [PATCH 22/27] fix cpplint --- src/runtime/contrib/thrust/thrust.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 931649d16aa7..7295d4c47c3f 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -341,7 +341,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") thrust_scan(data, output, exclusive); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are int32, int64, float32, and float64"; + << ". Supported output dtypes are int32, int64, float32, and float64"; } } else if (in_dtype == "int32") { if (out_dtype == "int32") { From e3cf96d2266092836f53c96060d6719aa4363ba3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 26 Jan 2021 06:47:14 +0900 Subject: [PATCH 23/27] fix missing check for the cuda target in nms thrust sort --- python/tvm/topi/cuda/nms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 32691da90ecc..2d6e1e464ef8 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -609,7 +609,8 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape): tag="fetch_score", ) - if is_thrust_available(): + target = tvm.target.Target.current() + if target and target.kind.name == "cuda" and is_thrust_available(): sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32") else: sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32") From fb2f14292e740d2a16dcdb332752841fd7ad0d2b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 26 Jan 2021 10:56:03 +0900 Subject: [PATCH 24/27] parallelize cpu cumsum --- python/tvm/topi/cumsum.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py index b6939a174274..855427b1c619 100644 --- a/python/tvm/topi/cumsum.py +++ b/python/tvm/topi/cumsum.py @@ -80,15 +80,16 @@ def gen_ir(data_buf, out_buf): data_buf = ib.buffer_ptr(data_buf) out_buf = ib.buffer_ptr(out_buf) - with ib.for_range(0, axis_mul_before, "i") as i: - with ib.for_range(0, axis_mul_after, "j") as j: - base_idx = i * cumsum_axis_len * axis_mul_after + j - out_buf[base_idx] = maybe_cast(data_buf[base_idx]) - with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k: - k = _k + 1 - cur_idx = base_idx + k * axis_mul_after - prev_idx = base_idx + (k - 1) * axis_mul_after - out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx]) + with ib.for_range(0, axis_mul_before * axis_mul_after, "fused", kind="parallel") as fused: + i = fused // axis_mul_after + j = fused % axis_mul_after + base_idx = i * cumsum_axis_len * axis_mul_after + j + out_buf[base_idx] = maybe_cast(data_buf[base_idx]) + with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k: + k = _k + 1 + cur_idx = base_idx + k * axis_mul_after + prev_idx = base_idx + (k - 1) * axis_mul_after + out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx]) return ib.get() From 4d8badc8c12e00d6f8441314e4f33e3bc4108236 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 26 Jan 2021 12:54:26 +0900 Subject: [PATCH 25/27] making binop argument tir function --- python/tvm/topi/cuda/scan.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 0505460dfb47..64a4818375ea 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -26,10 +26,13 @@ from .injective import schedule_injective_from_existing -binop_name_to_func = {"sum": tvm.tir.generic.add} +def _get_thrust_func_name(tvmop): + tvmop_to_thrust_func_name = {tvm.tir.generic.add: "tvm.contrib.thrust.sum_scan"} + assert tvmop in tvmop_to_thrust_func_name, "{} not supported by thrust".format(tvmop) + return tvmop_to_thrust_func_name[tvmop] -def exclusive_scan_ir(data, output, reduction=None, binop="sum"): +def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add): """Low level IR to do exclusive sum scan along rows of 2D input. Parameters @@ -93,7 +96,6 @@ def exclusive_scan_ir(data, output, reduction=None, binop="sum"): lim = tvm.tir.generic.cast( tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64" ) - op = binop_name_to_func[binop] with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width @@ -118,7 +120,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop="sum"): middle[0] = start[0] + tvm.tir.indexdiv(width, 2) end[0] = tvm.te.min(start[0] + width, scan_axis_size) with ib.if_scope(middle[0] < scan_axis_size): - output[by * scan_axis_size + end[0] - 1] = op( + output[by * scan_axis_size + end[0] - 1] = binop( output[by * scan_axis_size + end[0] - 1], output[by * scan_axis_size + middle[0] - 1], ) @@ -161,13 +163,13 @@ def exclusive_scan_ir(data, output, reduction=None, binop="sum"): output[by * scan_axis_size + middle[0] - 1] = output[ by * scan_axis_size + end[0] - 1 ] - output[by * scan_axis_size + end[0] - 1] = op( + output[by * scan_axis_size + end[0] - 1] = binop( output[by * scan_axis_size + end[0] - 1], tmp[0] ) return ib.get() -def get_reduction_from_exclusive_scan(data, ex_scan_output, binop="sum"): +def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generic.add): """Return the sum of the last element of data and the exclusive scan output. The is the reduction of data along each row (for 2-D case). @@ -213,7 +215,7 @@ def ir(data, data_ex_scan, reduction): tid = bx * max_threads + tx with ib.if_scope(tid < batch_size): with ib.if_scope(scan_axis_size > 0): - reduction[tid] = binop_name_to_func[binop]( + reduction[tid] = binop( data_ex_scan[tid * scan_axis_size + scan_axis_size - 1], data[tid, scan_axis_size - 1], ) @@ -248,7 +250,9 @@ def is_thrust_available(): return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None -def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, binop="sum"): +def scan_thrust( + data, output_dtype, exclusive=True, return_reduction=False, binop=tvm.tir.generic.add +): """Do exclusive or inclusive scan on 1D or multidimensional input, using thrust. Parameters @@ -281,12 +285,12 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, bino """ data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) - binop_to_thrust_func_name = {"sum": "tvm.contrib.thrust.sum_scan"} + output = te.extern( [data.shape], [data], lambda ins, outs: tvm.tir.call_packed( - binop_to_thrust_func_name[binop], ins[0], outs[0], exclusive + _get_thrust_func_name(binop), ins[0], outs[0], exclusive ), dtype=[output_dtype], in_buffers=[data_buf], @@ -303,7 +307,9 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, bino return output -def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None, binop="sum"): +def exclusive_scan( + data, axis=-1, return_reduction=False, output_dtype=None, binop=tvm.tir.generic.add +): """Do exclusive scan on 1D or multidimensional input. Parameters @@ -411,7 +417,7 @@ def do_scan(data, output_dtype): return output -def inclusive_scan(data, axis=-1, output_dtype=None, binop="sum"): +def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add): """Do inclusive scan on 1D or multidimensional input. Parameters @@ -438,7 +444,7 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop="sum"): if output_dtype is not None and data.dtype != output_dtype and output_dtype != "": data = cast(data, output_dtype) - return binop_name_to_func[binop](data, ex_scan) + return binop(data, ex_scan) def schedule_scan(outs): @@ -498,4 +504,4 @@ def cumsum(data, axis=None, dtype=None): axis = 0 data = reshape(data, (prod(data.shape),)) axis = get_const_int(axis) - return inclusive_scan(data, axis, output_dtype=dtype, binop="sum") + return inclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add) From af58e07a4ced81c96345fd18c347aa654b6db716 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 26 Jan 2021 13:04:49 +0900 Subject: [PATCH 26/27] update doc for binop --- python/tvm/topi/cuda/scan.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 64a4818375ea..a34e89677ceb 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -46,8 +46,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add): reduction: Buffer, optional (N-1)-D Buffer, to store the sum of each scan axis. - biop: string, optional - A string specifying which binary operator to use. Currently only "sum" is supported. + binop: function, optional + A binary associative op to use for scan. The function takes two TIR expressions + and produce a new TIR expression. """ batch_size = prod(data.shape[:-1]) @@ -181,8 +182,9 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generi ex_scan_output : tvm.te.Tensor The output of exclusive scan on data - biop: string, optional - A string specifying which binary operator to use. Currently only "sum" is supported. + binop: function, optional + A binary associative op to use for scan. The function takes two TIR expressions + and produce a new TIR expression. Returns ------- @@ -271,8 +273,10 @@ def scan_thrust( Reductions are computed as part of the upsweep pass, so there is no extra cost. If False, reductions are ignored. It must be False when exclusive is False. - biop: string, optional - A string specifying which binary operator to use. Currently only "sum" is supported. + binop: function, optional + A binary associative op to use for scan. Since we need to lookup the corresponding + thrust function, arbitrariy callables are not supported. Currently only + tvm.tir.generic.add can be passed in. Returns ------- @@ -329,8 +333,9 @@ def exclusive_scan( output_dtype: string, optional The dtype of the output scan tensor. If not provided, the dtype of the input is used. - biop: string, optional - A string specifying which binary operator to use. Currently only "sum" is supported. + binop: function, optional + A binary associative op to use for scan. The function takes two TIR expressions + and produce a new TIR expression. Returns ------- @@ -431,8 +436,9 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add): output_dtype: string, optional The dtype of the output scan tensor. If not provided, the dtype of the input is used. - biop: string, optional - A string specifying which binary operator to use. Currently only "sum" is supported. + binop: function, optional + A binary associative op to use for scan. The function takes two TIR expressions + and produce a new TIR expression. Returns ------- From 91196272a4a7a7f072c4f02b9534c7fa584e86de Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Jan 2021 05:22:50 +0900 Subject: [PATCH 27/27] doc update --- python/tvm/relay/op/transform.py | 4 +++- python/tvm/topi/cuda/scan.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 7e81be6c98b1..6785ff248612 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1345,7 +1345,9 @@ def cumsum(data, axis=None, dtype=None): The result has the same size as data, and the same shape as data if axis is not None. If axis is None, the result is a 1-d array. - Examples: + Examples + -------- + .. code-block:: python a = [[1,2,3], [4,5,6]] cumsum(a) # if axis is not provided, cumsum is done over the flattened input. diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index a34e89677ceb..232d679840fd 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -48,7 +48,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add): binop: function, optional A binary associative op to use for scan. The function takes two TIR expressions - and produce a new TIR expression. + and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute + prefix sum. """ batch_size = prod(data.shape[:-1]) @@ -184,7 +185,8 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generi binop: function, optional A binary associative op to use for scan. The function takes two TIR expressions - and produce a new TIR expression. + and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute + prefix sum. Returns ------- @@ -335,7 +337,8 @@ def exclusive_scan( binop: function, optional A binary associative op to use for scan. The function takes two TIR expressions - and produce a new TIR expression. + and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute + prefix sum. Returns ------- @@ -438,7 +441,8 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add): binop: function, optional A binary associative op to use for scan. The function takes two TIR expressions - and produce a new TIR expression. + and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute + prefix sum. Returns -------