From 40012fb530fcda1b863b5e66c01ccacefeb11024 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Mon, 25 Oct 2021 09:18:29 +0000 Subject: [PATCH 1/4] add activations and unary operators --- python/tvm/relay/frontend/paddlepaddle.py | 158 ++++++++++++++++++---- 1 file changed, 131 insertions(+), 27 deletions(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index ef361d6c55e8..aa8c0c72b097 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -827,38 +827,120 @@ def convert_shape(g, op, block): def convert_slice(g, op, block): """Operator converter for slice.""" - def parameter_process(starts, ends, axes, dshape): - new_axes = [] - new_starts = [] - new_ends = [] - pop_index = 0 - for i in range(max(axes) + 1): - new_axes.append(i) - if i in axes: - new_starts.append(starts[pop_index]) - new_ends.append(ends[pop_index]) - pop_index += 1 - else: - new_starts.append(0) - new_ends.append(dshape[i]) - return new_starts, new_ends, new_axes - data = g.get_node(op.input("Input")[0]) - dshape = infer_shape(data) - starts = op.attr("starts") - ends = op.attr("ends") + dims = len(infer_shape(data)) + axes = op.attr("axes") + indices = _expr.const(axes, dtype="int64") + decrease_axis = op.attr("decrease_axis") - if isinstance(starts, int): - starts = [starts] - if isinstance(ends, int): - ends = [ends] - if isinstance(axes, int): - axes = [axes] if isinstance(decrease_axis, int): decrease_axis = [decrease_axis] - starts, ends, axes = parameter_process(starts, ends, axes, dshape) - out = _op.strided_slice(data, begin=starts, end=ends) + + if op.input("StartsTensor"): + # if `starts` is a tensor + starts = g.get_node(op.input("StartsTensor")[0]) + starts = _infer_value(starts, g.get_params()) + elif op.input("StartsTensorList"): + # if `starts` is a list of tensor + starts = [] + for start_index in op.input("StartsTensorList"): + start_index = g.get_node(start_index).astype("int64") + starts.append(start_index) + starts = _op.concatenate(starts, axis=0) + starts = _infer_value(starts, g.get_params()) + else: + # if `starts` is constant value + starts = op.attr("starts") + + if len(axes) < dims: + # make the numel of `starts` be same with the rank of input tensor + if isinstance(starts, _expr.Expr): + starts = _op.scatter( + _op.const([0] * dims, dtype=infer_type(starts).checked_type.dtype), + indices, + starts, + axis=0, + ) + else: + base = [0] * dims + for i, axis in enumerate(axes): + base[axis] = starts[i] + starts = base + + if op.input("EndsTensor"): + # if `ends` is a tensor + ends = g.get_node(op.input("EndsTensor")[0]) + ends = _infer_value(ends, g.get_params()) + elif op.input("EndsTensorList"): + # if `ends` is a list of tensor + ends = [] + for end_index in op.input("EndsTensorList"): + end_index = g.get_node(end_index).astype("int64") + ends.append(end_index) + ends = _op.concatenate(ends, axis=0) + ends = _infer_value(ends, g.get_params()) + else: + # if `ends` is constant value + ends = op.attr("ends") + + if len(axes) < dims: + # make the numel of `ends` be same with the rank of input tensor + if isinstance(ends, _expr.Expr): + ends = _op.scatter( + _expr.const( + np.array([np.iinfo(np.int32).max] * dims), + dtype=infer_type(ends).checked_type.dtype, + ), + indices, + ends, + axis=0, + ) + else: + base = [np.iinfo(np.int32).max] * dims + for i, axis in enumerate(axes): + base[axis] = ends[i] + ends = base + + strides = None + if "StridesTensor" in op.input_names and 
op.input("StridesTensor"): + # if `strides` is a input tensor + strides = g.get_node(op.input("StridesTensor")[0]) + strides = _infer_value(strides, g.get_params()) + elif "StridesTensorList" in op.input_names and op.input("StridesTensorList"): + # if `strides` is a list of tensor + strides = [] + for strides_index in op.input("StridesTensorList"): + strides_index = g.get_node(strides_index).astype("int64") + strides.append(strides_index) + strides = _op.concatenate(strides, axis=0) + strides = _infer_value(strides, g.get_params()) + elif op.has_attr("strides"): + # if `strides` is constant value + strides = op.attr("strides") + else: + # default value for `strides` + strides = [1] * dims + + if len(axes) < dims: + # make the numel of `strides` be same with the rank of input tensor + if isinstance(strides, _expr.Expr): + strides = _op.scatter( + _expr.const( + np.array([1] * dims), + dtype=infer_type(strides).checked_type.dtype, + ), + indices, + strides, + axis=0, + ) + else: + base = [1] * dims + for i, axis in enumerate(axes): + base[axis] = strides[i] + strides = base + + out = _op.strided_slice(data, begin=starts, end=ends, strides=strides) if decrease_axis: out = _op.squeeze(out, axis=decrease_axis) g.add_node(op.output("Out")[0], out) @@ -900,15 +982,22 @@ def convert_unsqueeze(g, op, block): _convert_map = { + "abs": convert_unary_op, + "acos": convert_unary_op, "arg_max": convert_arg_max_min, "arg_min": convert_arg_max_min, "argsort": convert_argsort, + "asin": convert_unary_op, "assign": convert_assign, "assign_value": convert_assign_value, + "atan": convert_unary_op, "batch_norm": convert_batch_norm, "cast": convert_cast, + "ceil": convert_unary_op, "concat": convert_concat, "conv2d": convert_conv2d, + "cos": convert_unary_op, + "cosh": convert_unary_op, "cumsum": convert_cumsum, "depthwise_conv2d": convert_conv2d, "dot": convert_dot, @@ -918,12 +1007,14 @@ def convert_unsqueeze(g, op, block): "elementwise_mul": convert_elementwise_op, "elementwise_sub": convert_elementwise_op, "equal": convert_elementwise_op, + "erf": convert_unary_op, "exp": convert_unary_op, "expand_v2": convert_expand, "expand_as_v2": convert_expand_as, "feed": convert_feed, "fill_any_like": convert_fill_any_like, "fill_constant": convert_fill_constant, + "floor": convert_unary_op, "gelu": convert_gelu, "hard_sigmoid": convert_hard_sigmoid, "hard_swish": convert_hard_swish, @@ -932,6 +1023,11 @@ def convert_unsqueeze(g, op, block): "isnan_v2": convert_unary_op, "layer_norm": convert_layer_norm, "leaky_relu": convert_leaky_relu, + "less_equal": convert_elementwise_op, + "less_than": convert_elementwise_op, + "log": convert_unary_op, + "log2": convert_unary_op, + "log10": convert_unary_op, "logical_and": convert_binary_logical_op, "logical_or": convert_binary_logical_op, "logical_xor": convert_binary_logical_op, @@ -943,11 +1039,19 @@ def convert_unsqueeze(g, op, block): "pool2d": convert_pool2d, "relu": convert_unary_op, "reshape2": convert_reshape, + "round": convert_unary_op, + "rsqrt": convert_unary_op, "scale": convert_scale, "shape": convert_shape, + "sigmoid": convert_unary_op, + "sign": convert_unary_op, + "sin": convert_unary_op, + "sinh": convert_unary_op, "slice": convert_slice, "softmax": convert_softmax, + "sqrt": convert_unary_op, "squeeze2": convert_squeeze, + "tan": convert_unary_op, "tanh": convert_unary_op, "unsqueeze2": convert_unsqueeze, } From 8691ade3573c478295ff28f8951e3a2175b8212d Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Tue, 26 Oct 2021 07:27:24 +0000 Subject: [PATCH 2/4] 
revert modify of slice --- python/tvm/relay/frontend/paddlepaddle.py | 136 +++++----------------- 1 file changed, 27 insertions(+), 109 deletions(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index aa8c0c72b097..0f4edb602e12 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -827,120 +827,38 @@ def convert_shape(g, op, block): def convert_slice(g, op, block): """Operator converter for slice.""" - data = g.get_node(op.input("Input")[0]) - dims = len(infer_shape(data)) + def parameter_process(starts, ends, axes, dshape): + new_axes = [] + new_starts = [] + new_ends = [] + pop_index = 0 + for i in range(max(axes) + 1): + new_axes.append(i) + if i in axes: + new_starts.append(starts[pop_index]) + new_ends.append(ends[pop_index]) + pop_index += 1 + else: + new_starts.append(0) + new_ends.append(dshape[i]) + return new_starts, new_ends, new_axes + data = g.get_node(op.input("Input")[0]) + dshape = infer_shape(data) + starts = op.attr("starts") + ends = op.attr("ends") axes = op.attr("axes") - indices = _expr.const(axes, dtype="int64") - decrease_axis = op.attr("decrease_axis") + if isinstance(starts, int): + starts = [starts] + if isinstance(ends, int): + ends = [ends] + if isinstance(axes, int): + axes = [axes] if isinstance(decrease_axis, int): decrease_axis = [decrease_axis] - - if op.input("StartsTensor"): - # if `starts` is a tensor - starts = g.get_node(op.input("StartsTensor")[0]) - starts = _infer_value(starts, g.get_params()) - elif op.input("StartsTensorList"): - # if `starts` is a list of tensor - starts = [] - for start_index in op.input("StartsTensorList"): - start_index = g.get_node(start_index).astype("int64") - starts.append(start_index) - starts = _op.concatenate(starts, axis=0) - starts = _infer_value(starts, g.get_params()) - else: - # if `starts` is constant value - starts = op.attr("starts") - - if len(axes) < dims: - # make the numel of `starts` be same with the rank of input tensor - if isinstance(starts, _expr.Expr): - starts = _op.scatter( - _op.const([0] * dims, dtype=infer_type(starts).checked_type.dtype), - indices, - starts, - axis=0, - ) - else: - base = [0] * dims - for i, axis in enumerate(axes): - base[axis] = starts[i] - starts = base - - if op.input("EndsTensor"): - # if `ends` is a tensor - ends = g.get_node(op.input("EndsTensor")[0]) - ends = _infer_value(ends, g.get_params()) - elif op.input("EndsTensorList"): - # if `ends` is a list of tensor - ends = [] - for end_index in op.input("EndsTensorList"): - end_index = g.get_node(end_index).astype("int64") - ends.append(end_index) - ends = _op.concatenate(ends, axis=0) - ends = _infer_value(ends, g.get_params()) - else: - # if `ends` is constant value - ends = op.attr("ends") - - if len(axes) < dims: - # make the numel of `ends` be same with the rank of input tensor - if isinstance(ends, _expr.Expr): - ends = _op.scatter( - _expr.const( - np.array([np.iinfo(np.int32).max] * dims), - dtype=infer_type(ends).checked_type.dtype, - ), - indices, - ends, - axis=0, - ) - else: - base = [np.iinfo(np.int32).max] * dims - for i, axis in enumerate(axes): - base[axis] = ends[i] - ends = base - - strides = None - if "StridesTensor" in op.input_names and op.input("StridesTensor"): - # if `strides` is a input tensor - strides = g.get_node(op.input("StridesTensor")[0]) - strides = _infer_value(strides, g.get_params()) - elif "StridesTensorList" in op.input_names and op.input("StridesTensorList"): - # if `strides` is a 
list of tensor - strides = [] - for strides_index in op.input("StridesTensorList"): - strides_index = g.get_node(strides_index).astype("int64") - strides.append(strides_index) - strides = _op.concatenate(strides, axis=0) - strides = _infer_value(strides, g.get_params()) - elif op.has_attr("strides"): - # if `strides` is constant value - strides = op.attr("strides") - else: - # default value for `strides` - strides = [1] * dims - - if len(axes) < dims: - # make the numel of `strides` be same with the rank of input tensor - if isinstance(strides, _expr.Expr): - strides = _op.scatter( - _expr.const( - np.array([1] * dims), - dtype=infer_type(strides).checked_type.dtype, - ), - indices, - strides, - axis=0, - ) - else: - base = [1] * dims - for i, axis in enumerate(axes): - base[axis] = strides[i] - strides = base - - out = _op.strided_slice(data, begin=starts, end=ends, strides=strides) + starts, ends, axes = parameter_process(starts, ends, axes, dshape) + out = _op.strided_slice(data, begin=starts, end=ends) if decrease_axis: out = _op.squeeze(out, axis=decrease_axis) g.add_node(op.output("Out")[0], out) From 5258481d2b828a3ccef2e4905027073c3ed46d41 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Tue, 26 Oct 2021 08:13:26 +0000 Subject: [PATCH 3/4] add test cases --- python/tvm/relay/frontend/paddlepaddle.py | 34 +++++++++++++ .../frontend/paddlepaddle/test_forward.py | 49 +++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 0f4edb602e12..56d337c0b51d 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -763,6 +763,33 @@ def convert_pool2d(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_reduce(g, op, block): + """Operator converter for series of reduce operators.""" + + op_map = { + "reduce_all": "all", + "reduce_any": "any", + "reduce_max": "max", + "reduce_min": "min", + "reduce_prod": "prod", + "reduce_sum": "sum", + "reduce_mean": "mean", + } + op_name = op_map[op.type] + input_x = g.get_node(op.input("X")[0]) + axis = op.attr("dim") + if op.attr("reduce_all"): + axis = None + keepdims = op.attr("keep_dim") + out = get_relay_op(op_name)(input_x, axis=axis, keepdims=keepdims) + if not axis and not keepdims: + # use `expand_dims` to solve the following situation + # for TVM, the shape of `out` will be (, ) + # for Paddle, the shape of `out` will be [1] + out = _op.expand_dims(out, axis=0) + g.add_node(op.output("Out")[0], out) + + def convert_reshape(g, op, block): """Operator converter for reshape.""" @@ -958,6 +985,13 @@ def convert_unsqueeze(g, op, block): "relu": convert_unary_op, "reshape2": convert_reshape, "round": convert_unary_op, + "reduce_all": convert_reduce, + "reduce_any": convert_reduce, + "reduce_max": convert_reduce, + "reduce_min": convert_reduce, + "reduce_prod": convert_reduce, + "reduce_sum": convert_reduce, + "reduce_mean": convert_reduce, "rsqrt": convert_unary_op, "scale": convert_scale, "shape": convert_shape, diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index b274d178c9c2..513c4b206c4d 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -461,6 +461,8 @@ def forward(self, input1, input2): api_list = [ "equal", + "less_equal", + "less_than", ] x_shapes = [[128], [8, 20], [4, 20, 3], [2, 3, 8, 8], [2, 3, 3, 9, 9]] y_shapes = [[1], [8, 20], [4, 1, 1], 
[2, 3, 8, 8], [2, 3, 3, 9, 1]] @@ -799,6 +801,33 @@ def forward(self, inputs): verify_model(Pad3D(padding=3, mode="replicate"), input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_reduce(): + class Reduce(nn.Layer): + def __init__(self, op_name, axis=None, keepdim=False): + super(Reduce, self).__init__() + self.op_name = op_name + self.axis = axis + self.keepdim = keepdim + + @paddle.jit.to_static + def forward(self, inputs): + result = getattr(paddle, self.op_name)(inputs, axis=self.axis, keepdim=self.keepdim) + result = result.astype("float32") + return result + + input_shapes = [[1, 2, 2, 5, 5], [2, 3, 4], [4, 20], [2, 3, 30, 30]] + for input_shape in input_shapes: + input_data = paddle.uniform(min=-3, max=3, shape=input_shape, dtype="float32") + verify_model(Reduce("all"), input_data=input_data.astype("bool")) + verify_model(Reduce("any", 1), input_data=input_data.astype("bool")) + verify_model(Reduce("max", 0, True), input_data=input_data) + verify_model(Reduce("min", 1, True), input_data=input_data) + verify_model(Reduce("prod", 0), input_data=input_data) + verify_model(Reduce("sum", 0, True), input_data=input_data) + verify_model(Reduce("mean", -1, True), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_reshape(): @paddle.jit.to_static @@ -899,8 +928,28 @@ def forward(self, inputs): return self.func(inputs) api_list = [ + "abs", + "acos", + "asin", + "atan", + "ceil", + "cos", + "cosh", + "erf", "exp", + "floor", + "log", + "log2", + "log10", "relu", + "round", + "rsqrt", + "sigmoid", + "sign", + "sin", + "sinh", + "sqrt", + "tan", "tanh", ] input_shapes = [[128], [2, 100], [10, 2, 5], [7, 3, 4, 1]] From 3ddd37c768f40d1188b59200027e15b844fa1fa7 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Thu, 28 Oct 2021 02:55:43 +0000 Subject: [PATCH 4/4] disable signal capturing in paddle framework --- python/tvm/relay/frontend/paddlepaddle.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 56d337c0b51d..a568be469e98 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -1179,6 +1179,10 @@ def from_paddle(program_or_layer, shape_dict=None, scope=None): import paddle + # disable system signal capturing in paddle framework + # the signal capturing may cause conflict while running autotvm with paddle frontend + paddle.disable_signal_handler() + g = GraphProto() if isinstance(program_or_layer, paddle.jit.TranslatedLayer): # model is loaded by `paddle.jit.load`
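
For reference, the `parameter_process` helper restored by PATCH 2/4 can be exercised on its own. A minimal standalone sketch of its padding behavior, assuming plain Python lists as `convert_slice` passes them:

    def parameter_process(starts, ends, axes, dshape):
        # pad `starts`/`ends` with full-range defaults so that every
        # axis up to max(axes) is covered explicitly
        new_axes, new_starts, new_ends = [], [], []
        pop_index = 0
        for i in range(max(axes) + 1):
            new_axes.append(i)
            if i in axes:
                new_starts.append(starts[pop_index])
                new_ends.append(ends[pop_index])
                pop_index += 1
            else:
                new_starts.append(0)
                new_ends.append(dshape[i])
        return new_starts, new_ends, new_axes

    # slicing 1:3 along axis 1 of a (4, 5) tensor expands to per-axis lists
    print(parameter_process([1], [3], [1], (4, 5)))
    # -> ([0, 1], [4, 3], [0, 1])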
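The `expand_dims` branch in `convert_reduce` (PATCH 3/4) reconciles a rank mismatch: reducing over all axes without `keep_dim` yields a rank-0 tensor in Relay, while Paddle reports shape [1]. A NumPy sketch of the same reconciliation:

    import numpy as np

    out = np.ones((2, 3), dtype="float32").sum()  # reduce over all axes
    print(out.shape)                     # () -- rank-0, as Relay produces
    print(np.expand_dims(out, 0).shape)  # (1,) -- the shape Paddle reports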
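Putting the series together, a minimal end-to-end sketch (not part of the patches) of converting a Paddle model that exercises the newly mapped `sin` and `reduce_sum` operators; the model, save path, and input name "x" are assumptions for illustration:

    import paddle
    import tvm
    from tvm import relay

    class Net(paddle.nn.Layer):
        @paddle.jit.to_static
        def forward(self, x):
            return paddle.sum(paddle.sin(x), axis=0, keepdim=True)

    net = Net()
    net.eval()
    # trace to a static graph; InputSpec fixes shape, dtype, and name
    spec = [paddle.static.InputSpec(shape=[2, 3], dtype="float32", name="x")]
    paddle.jit.save(net, "./net", input_spec=spec)
    layer = paddle.jit.load("./net")  # a paddle.jit.TranslatedLayer

    # from_paddle now calls paddle.disable_signal_handler() itself
    # (PATCH 4/4), so running this under autotvm should be safe
    mod, params = relay.frontend.from_paddle(layer, shape_dict={"x": [2, 3]})
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)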