diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 76a12691d2bf..0aaa3c434edb 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -25,12 +25,15 @@ from tvm.ir import IRModule from .. import analysis +from .. import ty as _ty from .. import expr as _expr +from ..loops import while_loop from .. import function as _function from .. import ty as _ty from .. import op as _op from .common import ( fold_constant, + get_relay_op, infer_shape, infer_type, infer_value, @@ -40,14 +43,76 @@ __all__ = ["from_paddle"] -def shape_of(x, dtype="int32"): - """Get shape of a tensor""" +class ControlFlow: + """Control flow converter for PaddlePaddle.""" - ttype = infer_type(x).checked_type - if not _ty.is_dynamic(ttype): - shape = list(ttype.shape) - return _expr.const(shape, dtype) - return _op.shape_of(x, dtype) + operators = [ + "while", + ] + + @classmethod + def convert_block(cls, graph, block): + for op in block.ops: + if op.type in ControlFlow.operators: + raise Exception("Nested Control Flow Not Support Yet.") + convert_func = _convert_map[op.type] + convert_func(graph, op, block) + + @classmethod + def convert(cls, graph, op, program): + func = getattr(cls, "convert_{}".format(op.type)) + return func(graph, op, program) + + @classmethod + def convert_while(cls, graph, op, program): + """Operator converter for while.""" + + sub_block_id = op.attr("sub_block").id + sub_block = program.blocks[sub_block_id] + input_names = op.input("X") + output_names = op.output("Out") + cond_name = op.input("Condition")[0] + + for name in output_names: + if name == cond_name: + continue + if name not in input_names: + raise Exception("Output '{}' not in inputs".format(name)) + + sub_graph = GraphProto(graph.freeze_params) + sub_graph.set_params(graph.get_params()) + cond_var = _expr.var(cond_name, shape=[1], dtype="bool") + loop_vars = list() + loop_vars.append(cond_var) + for i, name in enumerate(op.input("X")): + shape = infer_shape(graph.get_node(name)) + dtype = program.blocks[0].var(name).dtype + dtype = str(dtype).strip().split(".")[1] + var = _expr.var(name, shape=shape, dtype=dtype) + loop_vars.append(var) + + def cond_fn(*loop_inputs): + squeezed_cond = _op.squeeze(loop_inputs[0]) + return _op.equal(squeezed_cond, _expr.const(True, "bool")) + + def body_fn(*loop_inputs): + body_inputs = loop_inputs[1:] + for i, ipt in enumerate(body_inputs): + sub_graph.add_node(input_names[i], ipt) + cls.convert_block(sub_graph, sub_block) + sub_outputs = [sub_graph.get_node(cond_name)] + sub_outputs += [sub_graph.get_node(name) for name in input_names] + return sub_outputs + + loop = while_loop(cond_fn, loop_vars, body_fn) + + init_cond = graph.get_node(op.input("Condition")[0]) + init_inputs = [graph.get_node(x) for x in op.input("X")] + init_loop_vars = init_inputs + + loop_vals = loop(init_cond, *init_loop_vars) + for i, name in enumerate(input_names): + graph.add_node(name, _expr.TupleGetItem(loop_vals, i + 1)) def _get_pad_size(in_size, dilated_kernel_size, stride_size): @@ -64,12 +129,128 @@ def _get_pad_size(in_size, dilated_kernel_size, stride_size): return [pad_before, pad_after] +def _dtype_shape_promotion(inputs): + """promote data type and shape for list of tensors.""" + + dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32", "float64"] + + ranks = [len(infer_shape(x)) for x in inputs] + if set(ranks) == set([1, 0]): + for i, r in enumerate(ranks): + if r == 0: + inputs[i] = 
_op.expand_dims(inputs[i], axis=0) + + dtypes = set(dtype_order.index(infer_type(x).checked_type.dtype) for x in inputs) + if len(dtypes) == 1: + return inputs + max_dtype = dtype_order[max(dtypes)] + for i, input_op in enumerate(inputs): + if infer_type(input_op).checked_type.dtype != max_dtype: + inputs[i] = input_op.astype(max_dtype) + return inputs + + +def shape_of(x, dtype="int32"): + """Get shape of a tensor""" + + ttype = infer_type(x).checked_type + if not _ty.is_dynamic(ttype): + shape = list(ttype.shape) + return _expr.const(np.array(shape), dtype) + return _op.shape_of(x, dtype) + + +def _infer_value(x, params): + """Try running infer_value, and if successful, return the inferred value. + Otherwise, return input""" + + try: + value = infer_value(x, params) + return value.numpy().tolist() + except Exception: # pylint: disable=broad-except + return x + + +def _convert_dtype_value(val): + """converts a Paddle type id to a string.""" + + convert_dtype_map = { + 21: "int8", + 20: "uint8", + 6: "float64", + 5: "float32", + 4: "float16", + 3: "int64", + 2: "int32", + 1: "int16", + 0: "bool", + } + if val not in convert_dtype_map: + msg = "Paddle data type value %d is not handled yet." % (val) + raise NotImplementedError(msg) + return convert_dtype_map[val] + + +def convert_unary_op(g, op, block): + """Operator converter for all the activation.""" + + op_map = { + "isinf_v2": _op.isinf, + "isfinite_v2": _op.isfinite, + "isnan_v2": _op.isnan, + } + if op.type in op_map: + unary_func = op_map[op.type] + else: + unary_func = get_relay_op(op.type) + out = unary_func(g.get_node(op.input("X")[0])) + g.add_node(op.output("Out")[0], out) + + +def convert_addmm(g, op, block): + """Operator converter for addmm.""" + + input_x = g.get_node(op.input("Input")[0]) + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + + alpha = op.attr("Alpha") + beta = op.attr("Beta") + dtype = block.var(op.output("Out")[0]).dtype + dtype = str(dtype).strip().split(".")[1] + + if not isinstance(alpha, _expr.Expr) and alpha != 1: + alpha = _expr.const(alpha, dtype) + x *= alpha + + if not isinstance(beta, _expr.Expr) and beta != 1: + beta = _expr.const(beta, dtype) + input_x *= beta + + transposed_y = _op.transpose(y, axes=[1, 0]) + dense_out = _op.nn.dense(x, transposed_y) + out = dense_out + input_x + g.add_node(op.output("Out")[0], out) + + +def convert_addn(g, op, block): + """Operator converter for sum(add_n).""" + + inputs = op.input("X") + out = g.get_node(inputs[0]) + for i in range(1, len(inputs)): + out += g.get_node(inputs[i]) + g.add_node(op.output("Out")[0], out) + + def convert_arg_max(g, op, block): """Operator converter for arg_max.""" axis = op.attr("axis") keepdims = op.attr("keepdims") flatten = op.attr("flatten") + dtype = op.attr("dtype") + dtype = _convert_dtype_value(dtype) x = g.get_node(op.input("X")[0]) if axis is None or flatten: @@ -77,13 +258,64 @@ def convert_arg_max(g, op, block): out = _op.argmax(x, axis=None, keepdims=True) else: out = _op.argmax(x, axis=axis, keepdims=keepdims) + if dtype != infer_type(out).checked_type.dtype: + out = _op.cast(out, dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_arg_min(g, op, block): + """Operator converter for arg_min.""" + + axis = op.attr("axis") + keepdims = op.attr("keepdims") + flatten = op.attr("flatten") + dtype = op.attr("dtype") + dtype = _convert_dtype_value(dtype) + + x = g.get_node(op.input("X")[0]) + if axis is None or flatten: + x = _op.reshape(x, [-1]) + out = _op.argmin(x, axis=None, keepdims=True) + 
else: + out = _op.argmin(x, axis=axis, keepdims=keepdims) + if dtype != infer_type(out).checked_type.dtype: + out = _op.cast(out, dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_argsort(g, op, block): + """Operator converter for argsort.""" + + x = g.get_node(op.input("X")[0]) + axis = op.attr("axis") + descending = op.attr("descending") + + out = _op.sort(x, axis, not descending) + out_indice = _op.argsort(x, axis, not descending, dtype="int64") g.add_node(op.output("Out")[0], out) + g.add_node(op.output("Indices")[0], out_indice) def convert_assign(g, op, block): """Operator converter for assign.""" - out = _op.copy(g.get_node(op.input("X")[0])) + out = g.get_node(op.input("X")[0]) + g.add_node(op.output("Out")[0], out) + + +def convert_assign_value(g, op, block): + """Operator converter for assign_value.""" + + keys = ["bool_values", "fp32_values", "int32_values", "int64_values"] + dtypes = ["bool", "float32", "int32", "int64"] + for i, key in enumerate(keys): + dtype = dtypes[i] + value = np.array(op.attr(key)).astype(dtype) + if value is not None and value.size >= 1: + break + shape = op.attr("shape") + value = value.reshape(shape) + out = _op.const(value, dtype=dtype) g.add_node(op.output("Out")[0], out) @@ -107,21 +339,169 @@ def convert_batch_norm(g, op, block): g.add_node(op.output("Y")[0], out[0]) +def convert_bmm(g, op, block): + """Operator converter for bmm.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + y = _op.transpose(y, [0, 2, 1]) + out = _op.nn.batch_matmul(x, y) + g.add_node(op.output("Out")[0], out) + + +def convert_interpolate2d(g, op, x): + """Operator converter for interpolate 2D(dims == 4).""" + + def get_interpolate_mode(op): + """conver 'interp_method' attr of paddle to tvm""" + + interp_method = op.attr("interp_method") + align_corners = op.attr("align_corners") + align_mode = op.attr("align_mode") + + rounding_method = "" + if interp_method == "nearest": + interp_method = "nearest_neighbor" + coordinate_transformation_mode = "asymmetric" + rounding_method = "floor" + elif interp_method == "bilinear": + interp_method = "linear" + if not align_corners and align_mode == 0: + coordinate_transformation_mode = "half_pixel" + else: + if align_corners: + coordinate_transformation_mode = "align_corners" + else: + coordinate_transformation_mode = "asymmetric" + elif interp_method == "bicubic": + interp_method = "cubic" + if align_corners: + coordinate_transformation_mode = "align_corners" + else: + coordinate_transformation_mode = "half_pixel" + else: + msg = "interp_method {} is not supported for PaddlePaddle's interpolate" + raise tvm.error.OpAttributeInvalid(msg.format(interp_method)) + return rounding_method, interp_method, coordinate_transformation_mode + + layout = op.attr("data_layout") + out_h = op.attr("out_h") + out_w = op.attr("out_w") + out_size = [out_h, out_w] + + input_out_size = op.input("OutSize") + input_size_tensor = op.input("SizeTensor") + input_scale = op.input("Scale") + if input_size_tensor: + out_size = g.get_node(input_size_tensor[0]) + out_size = _infer_value(out_size, g.get_params()) + elif input_out_size: + out_size = g.get_node(input_out_size[0]) + out_size = _infer_value(out_size, g.get_params()) + else: + input_shape = infer_shape(x) + if layout == "NCHW": + in_h, in_w = input_shape[2], input_shape[3] + else: + in_h, in_w = input_shape[1], input_shape[2] + if input_scale: + scale_data = g.get_node(input_scale[0]) + scale_data = infer_value(scale_data, g.get_params()).numpy().tolist() + if 
len(scale_data) > 1: + out_h = int(scale_data[0] * in_h) + out_w = int(scale_data[1] * in_w) + else: + out_h = int(scale_data[0] * in_h) + out_w = int(scale_data[0] * in_w) + out_size = [out_h, out_w] + else: + scale = op.attr("scale") + scale = [float(i) for i in scale] + if len(scale) > 1: + out_h = int(scale[0] * in_h) + out_w = int(scale[1] * in_w) + out_size = [out_h, out_w] + + rounding_method, interp_method, coordinate_transformation_mode = get_interpolate_mode(op) + out = _op.image.resize2d( + x, + size=out_size, + layout=layout, + method=interp_method, + coordinate_transformation_mode=coordinate_transformation_mode, + rounding_method=rounding_method, + cubic_alpha=-0.75, + ) + g.add_node(op.output("Out")[0], out) + + +def convert_interpolate(g, op, block): + """Operator converter for interpolate.""" + + x = g.get_node(op.input("X")[0]) + layout = op.attr("data_layout") + if layout in ("NCHW", "NHWC"): + convert_interpolate2d(g, op, x) + else: + msg = "layout {} is not supported for PaddlePaddle's interpolate" + raise tvm.error.OpAttributeInvalid(msg.format(layout)) + + def convert_cast(g, op, block): """Operator converter for cast.""" - dtype = block.var(op.output("Out")[0]).dtype - dtype = str(dtype).strip().split(".")[1] + dtype = op.attr("out_dtype") + dtype = _convert_dtype_value(dtype) x = g.get_node(op.input("X")[0]) out = _op.cast(x, dtype=dtype) g.add_node(op.output("Out")[0], out) +def convert_clip(g, op, block): + """Operator converter for clip.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + is_dynamic = False + if op.input("Min"): + min_value = g.get_node(op.input("Min")[0]) + min_value = _infer_value(min_value, g.get_params()) + if isinstance(min_value, _expr.Expr): + is_dynamic = True + else: + min_value = min_value[0] + else: + min_value = op.attr("min") + if op.input("Max"): + max_value = g.get_node(op.input("Max")[0]) + max_value = _infer_value(max_value, g.get_params()) + if isinstance(max_value, _expr.Expr): + if not is_dynamic: + is_dynamic = True + min_value = _op.const(min_value, dtype) + else: + max_value = max_value[0] + if is_dynamic: + max_value = _op.const(max_value, dtype) + else: + max_value = op.attr("max") + if is_dynamic: + max_value = _op.const(max_value, dtype) + + if not is_dynamic: + out = _op.clip(x, min_value, max_value) + else: + out = _op.maximum(x, min_value) + out = _op.minimum(out, max_value) + g.add_node(op.output("Out")[0], out) + + def convert_concat(g, op, block): """Operator converter for concat.""" inputs = [g.get_node(op.input("X")[i]) for i in range(len(op.input("X")))] axis = op.attr("axis") + inputs = _dtype_shape_promotion(inputs) out = _op.concatenate(inputs, axis=axis) g.add_node(op.output("Out")[0], out) @@ -138,12 +518,22 @@ def convert_conv2d(g, op, block): kernel = g.get_node(op.input("Filter")[0]) input_x = g.get_node(op.input("Input")[0]) out_channels, _, k_h, k_w = infer_shape(kernel) - in_h, in_w = infer_shape(input_x)[2:] if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) - pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) + if strides[0] == 1 and strides[1] == 1: + pad_h = _get_pad_size(0, (k_h - 1) * dilations[0] + 1, strides[0]) + pad_w = _get_pad_size(0, (k_w - 1) * dilations[1] + 1, strides[1]) + else: + input_shape = shape_of(input_x) + h_w = _op.strided_slice(input_shape, [2], [4]) + try: + in_h, in_w = infer_value(h_w, 
g.get_params()).numpy().tolist() + except Exception as e: + msg = "The SAME padding algorithm of Conv not support dynamic shape" + raise tvm.error.OpAttributeInvalid(msg) from e + pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) + pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: @@ -167,6 +557,90 @@ def convert_conv2d(g, op, block): g.add_node(op.output("Output")[0], out) +def convert_conv2d_transpose(g, op, block): + """Operator converter for conv2d_transpose.""" + + dilations = op.attr("dilations") + groups = op.attr("groups") + paddings = op.attr("paddings") + padding_algorithm = op.attr("padding_algorithm") + strides = op.attr("strides") + output_padding = op.attr("output_padding") if op.attr("output_padding") else [0, 0] + + kernel = g.get_node(op.input("Filter")[0]) + input_x = g.get_node(op.input("Input")[0]) + _, out_channels, k_h, k_w = infer_shape(kernel) + if padding_algorithm == "VALID": + paddings = [0, 0] + elif padding_algorithm == "SAME": + if strides[0] == 1 and strides[1] == 1: + pad_h = _get_pad_size(0, (k_h - 1) * dilations[0] + 1, strides[0]) + pad_w = _get_pad_size(0, (k_w - 1) * dilations[1] + 1, strides[1]) + else: + input_shape = shape_of(input_x) + h_w = _op.strided_slice(input_shape, [2], [4]) + try: + in_h, in_w = infer_value(h_w, g.get_params()).numpy().tolist() + except Exception as e: + msg = "The SAME padding algorithm of Conv_Transpose not support dynamic shape" + raise tvm.error.OpAttributeInvalid(msg) from e + pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) + pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) + paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + elif padding_algorithm == "EXPLICIT": + if len(paddings) == 2: + paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] + if len(paddings) == 4: + paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] + else: + msg = 'Value {} in attribute "padding" of operator Conv is not "valid."' + raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + + out = _op.nn.conv2d_transpose( + input_x, + kernel, + strides=strides, + padding=paddings, + dilation=dilations, + groups=groups, + channels=out_channels, + kernel_size=[k_h, k_w], + output_padding=output_padding, + ) + g.add_node(op.output("Output")[0], out) + + +def convert_crop(g, op, block): + """Operator converter for crop.""" + + x = g.get_node(op.input("X")[0]) + dims = len(infer_shape(x)) + input_shape = op.input("Shape") + input_offsets = op.input("Offsets") + if input_shape: + shape = g.get_node(input_shape[0]) + shape = _infer_value(shape, g.get_params()) + else: + shape = op.attr("shape") + + if input_offsets: + offsets = g.get_node(input_offsets[0]) + offsets = _infer_value(offsets, g.get_params()) + else: + offsets = op.attr("offsets") + + if not isinstance(shape, _expr.Expr): + shape = _op.const(shape, "int32") + if not isinstance(offsets, _expr.Expr): + offsets = _op.const(offsets, "int32") + slice_start = offsets + slice_end = _op.add(shape, offsets) + strides = _op.const([1] * dims, dtype="int32") + + out = _op.strided_slice(x, slice_start, slice_end, strides) + g.add_node(op.output("Out")[0], out) + + def convert_cumsum(g, op, block): """Operator converter for cumsum.""" @@ -191,7 +665,51 @@ def convert_dropout(g, op, block): """Operator converter for dropout.""" x = g.get_node(op.input("X")[0]) - out = _op.copy(x) + 
g.add_node(op.output("Out")[0], x) + + +def convert_elu(g, op, block): + """Operator converter for elu.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + alpha = op.attr("alpha") + alpha = _expr.const(-1.0 * alpha, dtype=dtype) + out = alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(x)) + _op.nn.relu(x) + g.add_node(op.output("Out")[0], out) + + +def convert_dist(g, op, block): + """Operator converter for dist.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + dtype = infer_type(x).checked_type.dtype + p = op.attr("p") + + x -= y + if p == np.inf: + out = _op.reduce.max(_op.abs(x)) + elif p == np.NINF: + out = _op.reduce.min(_op.abs(x)) + else: + reci_order = _expr.const(1.0 / p, dtype=dtype) + p = _expr.const(p) + out = _op.power( + _op.reduce.sum(_op.power(_op.abs(x), p)), + reci_order, + ) + out = _op.expand_dims(out, axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_dot(g, op, block): + """Operator converter for dot.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + + out = _op.sum(_op.multiply(x, y), axis=[-1], keepdims=True) g.add_node(op.output("Out")[0], out) @@ -199,49 +717,59 @@ def convert_elementwise_op(g, op, block): """Operator converter for all the elementwise operators.""" op_map = { - "elementwise_div": lambda x, y: x / y, - "elementwise_add": lambda x, y: x + y, - "elementwise_mul": lambda x, y: x * y, - "elementwise_sub": lambda x, y: x - y, - "elementwise_mod": lambda x, y: x % y, + "elementwise_div": "divide", + "elementwise_add": "add", + "elementwise_mul": "multiply", + "elementwise_sub": "subtract", + "elementwise_mod": "mod", + "elementwise_max": "maximum", + "elementwise_min": "minimum", + "elementwise_pow": "power", + "elementwise_floordiv": "floor_divide", + "floor_mod": "floor_mod", + "equal": "equal", + "greater_equal": "greater_equal", + "greater_than": "greater", + "less_equal": "less_equal", + "less_than": "less", + "not_equal": "not_equal", } op_func = op_map[op.type] ipt0 = g.get_node(op.input("X")[0]) ipt1 = g.get_node(op.input("Y")[0]) - ipt0_shape = block.var(op.input("X")[0]).shape - ipt1_shape = block.var(op.input("Y")[0]).shape + ipt0_shape = infer_shape(ipt0) + ipt1_shape = infer_shape(ipt1) axis = op.attr("axis") if len(ipt0_shape) != len(ipt1_shape): if axis < 0: axis = axis + len(ipt0_shape) if axis != len(ipt0_shape) - 1: ipt1 = _op.expand_dims(ipt1, axis=axis, num_newaxis=(len(ipt0_shape) - axis - 1)) + op_func = get_relay_op(op_func) out = op_func(ipt0, ipt1) g.add_node(op.output("Out")[0], out) -def convert_equal(g, op, block): - """Operator converter for equal.""" +def convert_expand(g, op, block): + """Operator converter for expand.""" x = g.get_node(op.input("X")[0]) - y = g.get_node(op.input("Y")[0]) - out = _op.equal(x, y) + if op.input("Shape"): + sizes = g.get_node(op.input("Shape")[0]) + sizes = _infer_value(sizes, g.get_params()) + else: + sizes = op.attr("shape") + + out = _op.broadcast_to(x, sizes) g.add_node(op.output("Out")[0], out) -def convert_activation(g, op, block): - """Operator converter for all the activation.""" +def convert_expand_as(g, op, block): + """Operator converter for expand_as.""" - op_map = { - "exp": _op.exp, - "relu": _op.nn.relu, - "tanh": _op.tanh, - "sqrt": _op.sqrt, - "erf": _op.erf, - "abs": _op.abs, - } - act_func = op_map[op.type] - out = act_func(g.get_node(op.input("X")[0])) + x = g.get_node(op.input("X")[0]) + target_shape = op.attr("target_shape") + out = _op.broadcast_to(x, 
target_shape) g.add_node(op.output("Out")[0], out) @@ -259,6 +787,12 @@ def convert_feed(g, op, block): ipt_name = op.name if g.shape_dict is not None: ipt_shape = g.shape_dict[ipt_name] + + if isinstance(ipt_shape, tuple): + ipt_shape = list(ipt_shape) + for i, s in enumerate(ipt_shape): + if s < 0: + ipt_shape[i] = _ty.Any() out = new_var(ipt_name, shape=ipt_shape, dtype=ipt_dtype) g.add_node(ipt_name, out) @@ -266,18 +800,11 @@ def convert_feed(g, op, block): def convert_fill_any_like(g, op, block): """Operator converter for fill_any_like.""" - out_name = op.output("Out")[0] - out_dtype = block.var(out_name).dtype - out_dtype = str(out_dtype).strip().split(".")[1] + dtype = op.attr("dtype") + dtype = _convert_dtype_value(dtype) x = g.get_node(op.input("X")[0]) - ipt_type = infer_type(x).checked_type - value = op.attr("value") - if not _ty.is_dynamic(ipt_type): - shape = infer_shape(x) - const = np.ones(shape) * value - out = _expr.const(const.astype(out_dtype)) - else: - out = _op.transform.full_like(x, value).astype(out_dtype) + value = _expr.const(op.attr("value"), dtype=dtype) + out = _op.transform.full_like(x, value).astype(dtype) g.add_node(op.output("Out")[0], out) @@ -286,16 +813,100 @@ def convert_fill_constant(g, op, block): value = op.attr("value") shape = block.var(op.output("Out")[0]).shape - dtype = block.var(op.output("Out")[0]).dtype - dtype = str(dtype).strip().split(".")[1] - if op.input("ValueTensor"): + dtype = op.attr("dtype") + dtype = _convert_dtype_value(dtype) + value = _expr.const(value).astype(dtype) + if "ValueTensor" in op.input_names and op.input("ValueTensor"): shape = g.get_node(op.input("ValueTensor")[0]) - shape = infer_value(shape, g.get_params()).numpy() - if op.input("ShapeTensor"): + shape = _infer_value(shape, g.get_params()) + if "ShapeTensor" in op.input_names and op.input("ShapeTensor"): shape = g.get_node(op.input("ShapeTensor")[0]) - shape = infer_value(shape, g.get_params()).numpy() - value = np.full(shape, value, dtype) - out = _expr.const(value.astype(dtype)).astype(dtype) + shape = _infer_value(shape, g.get_params()) + + out = _op.full(value, shape=shape, dtype=dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_fill_constant_batch_size_like(g, op, block): + """Operator converter for fill_constant_batch_size_like.""" + + x = g.get_node(op.input("Input")[0]) + value = op.attr("value") + shape = op.attr("shape") + input_dim_idx = op.attr("input_dim_idx") + output_dim_idx = op.attr("output_dim_idx") + dtype = op.attr("dtype") + + dtype = _convert_dtype_value(dtype) + input_shape = shape_of(x) + batch = _op.strided_slice(input_shape, begin=[input_dim_idx], end=[input_dim_idx + 1]).astype( + "int32" + ) + shape_before = shape[:output_dim_idx] + shape_before = _expr.const(shape_before, dtype="int32") + shape_after = shape[output_dim_idx + 1 :] + shape_after = _expr.const(shape_after, dtype="int32") + + out_shape = _op.concatenate([shape_before, batch, shape_after], axis=0) + out_shape = _infer_value(out_shape, g.get_params()) + constant = _expr.const(value, dtype=dtype).astype(dtype) + out = _op.full(constant, out_shape, dtype=dtype) + + # reshape for dynamic + if isinstance(out_shape, _expr.Expr): + shape[output_dim_idx] = -1 + out = _op.reshape(out, shape) + + g.add_node(op.output("Out")[0], out) + + +def convert_flatten(g, op, block): + """Operator converter for flatten.""" + + x = g.get_node(op.input("X")[0]) + input_shape = list(infer_shape(x)) + + start = op.attr("start_axis") + end = op.attr("stop_axis") + ndim = 
len(input_shape) + if end < 0: + end += ndim + new_shape = [0] * start + + new_shape.append(-1) + squeeze_axes = [] + for i in range(start + 1, end + 1): + new_shape.append(1) + squeeze_axes.append(i) + for _ in range(end + 1, ndim): + new_shape.append(0) + out = _op.reshape(x, new_shape) + if squeeze_axes: + out = _op.squeeze(out, axis=squeeze_axes) + + g.add_node(op.output("Out")[0], out) + + +def convert_gather(g, op, block): + """Operator converter for gather.""" + + x = g.get_node(op.input("X")[0]) + index = g.get_node(op.input("Index")[0]) + axis = op.attr("axis") + out = _op.take(x, index, axis) + g.add_node(op.output("Out")[0], out) + + +def convert_gather_nd(g, op, block): + """Operator converter for gather_nd.""" + + x = g.get_node(op.input("X")[0]) + index = g.get_node(op.input("Index")[0]) + shape = infer_shape(index) + perm = list(range(0, len(shape) - 1)) + perm.insert(0, len(shape) - 1) + index = _op.transpose(index, axes=perm) + out = _op.gather_nd(x, index, 0, shape[-1]) g.add_node(op.output("Out")[0], out) @@ -310,12 +921,46 @@ def convert_gelu(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_group_norm(g, op, block): + """Operator converter for group_norm.""" + + x = g.get_node(op.input("X")[0]) + num_groups = op.attr("groups") + epsilon = op.attr("epsilon") + gamma = g.get_node(op.input("Scale")[0]) + beta = g.get_node(op.input("Bias")[0]) + out = _op.nn.group_norm( + x, + gamma=gamma, + beta=beta, + num_groups=num_groups, + axis=1, + epsilon=epsilon, + center=True, + scale=True, + ) + g.add_node(op.output("Y")[0], out) + + +def convert_hard_shrink(g, op, block): + """Operator converter for hard_shrink.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + threshold = op.attr("threshold") + threshold = _op.const(threshold, dtype) + out = _op.logical_or(x < _op.const(-1.0, dtype) * threshold, x > threshold) + out = _op.cast(out, dtype) * x + g.add_node(op.output("Out")[0], out) + + def convert_hard_sigmoid(g, op, block): """Operator converter for hard_sigmoid.""" slope = op.attr("slope") x = g.get_node(op.input("X")[0]) - out = x * _expr.const(slope) + _expr.const(0.5) + dtype = infer_type(x).checked_type.dtype + out = x * _expr.const(slope, dtype) + _expr.const(0.5, dtype) out = _op.clip(out, 0, 1) g.add_node(op.output("Out")[0], out) @@ -330,12 +975,47 @@ def convert_hard_swish(g, op, block): assert np.isclose(scale, 6.0), "Only support scale==6.0 for PaddlePaddle's hard_swish" assert np.isclose(threshold, 6.0), "Only support threshold==6.0 for PaddlePaddle's hard_swish" x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype out = _op.clip(x, -1 * offset, offset) - out = out / _expr.const(threshold) + _expr.const(0.5) + out = out / _expr.const(threshold, dtype) + _expr.const(0.5, dtype) out = x * out g.add_node(op.output("Out")[0], out) +def convert_hard_tanh(g, op, block): + """Operator converter for hard_tanh.""" + + x = g.get_node(op.input("X")[0]) + t_max = op.attr("t_max") + t_min = op.attr("t_min") + out = _op.tensor.clip(x, t_min, t_max) + g.add_node(op.output("Out")[0], out) + + +def convert_index_select(g, op, block): + """Operator converter for index_select.""" + + dim = op.attr("dim") + x = g.get_node(op.input("X")[0]) + index = g.get_node(op.input("Index")[0]) + out = _op.take(x, indices=index, axis=dim, mode="clip") + + g.add_node(op.output("Out")[0], out) + + +def convert_instance_norm(g, op, block): + """Operator converter for instance_norm.""" + + x = g.get_node(op.input("X")[0]) + 
gamma = g.get_node(op.input("Scale")[0]) + beta = g.get_node(op.input("Bias")[0]) + epsilon = op.attr("epsilon") + + scale = center = True + out = _op.nn.instance_norm(x, gamma, beta, axis=1, epsilon=epsilon, center=center, scale=scale) + g.add_node(op.output("Y")[0], out) + + def convert_layer_norm(g, op, block): """Operator converter for layer_norm.""" @@ -381,14 +1061,105 @@ def convert_lookup_table(g, op, block): indices = g.get_node(op.input("Ids")[0]) padding_idx = op.attr("padding_idx") - if padding_idx != -1: - g.get_params[op.input("W")[0]][padding_idx] = 0.0 - g.add_node(op.input("W")[0], _expr.const(g.params[op.input("W")[0]])) weights = g.get_node(op.input("W")[0]) + if padding_idx != -1: + if op.input("W")[0] in g.get_params(): + weights = g.get_params(op.input("W")[0]) + weights[padding_idx] = 0.0 + weights = _expr.const(weights) + else: + shape = _infer_value(shape_of(weights), g.get_params()) + assert not isinstance( + shape, _expr.Expr + ), "Shape of weight has to be fixed for PaddlePaddle's lookup_table" + filters = np.ones(shape).astype(infer_type(weights).checked_type.dtype) + filters[padding_idx] = 0.0 + filters = _expr.const(filters) + weights = weights * filters out = _op.take(weights, indices.astype("int32"), axis=0) g.add_node(op.output("Out")[0], out) +def convert_log1p(g, op, block): + """Operator converter for log1p.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + one = _expr.const(1, dtype=dtype) + out = _op.log(x + one) + g.add_node(op.output("Out")[0], out) + + +def convert_logical_op(g, op, block): + """Operator converter for logical op.""" + + ipt0 = g.get_node(op.input("X")[0]) + ipt1 = g.get_node(op.input("Y")[0]) + op_func = get_relay_op(op.type) + out = op_func(ipt0, ipt1) + g.add_node(op.output("Out")[0], out) + + +def convert_logical_not(g, op, block): + """Operator converter for logical_not op.""" + + ipt0 = g.get_node(op.input("X")[0]) + op_func = get_relay_op(op.type) + out = op_func(ipt0) + g.add_node(op.output("Out")[0], out) + + +def convert_logsigmoid(g, op, block): + """Operator converter for logsigmoid.""" + + x = g.get_node(op.input("X")[0]) + out = _op.log(_op.tensor.sigmoid(x)) + g.add_node(op.output("Out")[0], out) + + +def convert_logsoftmax(g, op, block): + """Operator converter for logsoftmax.""" + + x = g.get_node(op.input("X")[0]) + axis = op.attr("axis") + ndim = len(infer_shape(x)) + if axis < 0: + axis += ndim + m = _op.max(x, [axis], keepdims=True) + e = _op.exp(x - m) + s = _op.sum(e, [axis], keepdims=True) + out = x - m - _op.log(s) + g.add_node(op.output("Out")[0], out) + + +def convert_logsumexp(g, op, block): + """Operator converter for logsumexp.""" + + input_x = g.get_node(op.input("X")[0]) + axis = op.attr("axis") + if op.attr("reduce_all"): + axis = None + keepdims = op.attr("keepdim") + out = get_relay_op("logsumexp")(input_x, axis=axis, keepdims=keepdims) + if not axis and not keepdims: + out = _op.expand_dims(out, axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_masked_select(g, op, block): + """Operator converter for masked_select.""" + + x = g.get_node(op.input("X")[0]) + mask = g.get_node(op.input("Mask")[0]) + index = _op.transform.argwhere(mask) + shape = infer_shape(index) + perm = list(range(0, len(shape) - 1)) + perm.insert(0, len(shape) - 1) + index = _op.transpose(index, axes=perm) + out = _op.gather_nd(x, index, 0, shape[-1]) + g.add_node(op.output("Y")[0], out) + + def convert_matmul(g, op, block): """Operator converter for matmul.""" @@ -498,6 +1269,16 
@@ def flatten_to_nd(x, x_shape, nd=3): g.add_node(op.output("Out")[0], out) +def convert_meshgrid(g, op, block): + """Operator converter for meshgrid.""" + + inputs = op.input("X") + x = [g.get_node(i) for i in inputs] + outs = _op.meshgrid(x, indexing="ij") + for i, out in enumerate(outs): + g.add_node(op.output("Out")[i], out) + + def convert_mul(g, op, block): """Operator converter for mul.""" @@ -505,8 +1286,8 @@ def convert_mul(g, op, block): y = g.get_node(op.input("Y")[0]) x_num_col_dims = op.attr("x_num_col_dims") y_num_col_dims = op.attr("y_num_col_dims") - x_shape = shape_of(x) - y_shape = shape_of(y) + x_shape = _op.shape_of(x) + y_shape = _op.shape_of(y) x_dim = infer_shape(x_shape)[0] y_dim = infer_shape(y_shape)[0] if x_num_col_dims < 0: @@ -543,6 +1324,37 @@ def convert_mul(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_mv(g, op, block): + """Operator converter for mv.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Vec")[0]) + y = _op.expand_dims(y, axis=-1) + y = _op.transpose(y) + out = _op.nn.dense(x, y) + out = _op.squeeze(out, axis=[-1]) + g.add_node(op.output("Out")[0], out) + + +def convert_numel(g, op, block): + """Operator converter for numel.""" + + input_x = g.get_node(op.input("Input")[0]) + out = _op.ndarray_size(input_x, dtype="int64") + out = _op.expand_dims(out, axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_nonzero(g, op, block): + """Operator converter for nonzero.""" + + input_x = g.get_node(op.input("Condition")[0]) + out = _op.transform.argwhere(input_x) + # Paddle NonZero always outputs int64 + out = _op.cast(out, "int64") + g.add_node(op.output("Out")[0], out) + + def convert_pool2d(g, op, block): """Operator converter for pool2d.""" @@ -558,7 +1370,7 @@ def convert_pool2d(g, op, block): ksize = [1, 1] input_x = g.get_node(op.input("X")[0]) - in_h, in_w = infer_shape(input_x)[2:] + _, _, in_h, in_w = infer_shape(input_x) op_map = { "avg": "avg_pool2d", @@ -575,8 +1387,19 @@ def convert_pool2d(g, op, block): if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - pad_h = _get_pad_size(in_h, ksize[0], strides[0]) - pad_w = _get_pad_size(in_w, ksize[1], strides[1]) + if strides[0] == 1 and strides[1] == 1: + pad_h = _get_pad_size(0, ksize[0], strides[0]) + pad_w = _get_pad_size(0, ksize[1], strides[1]) + else: + input_shape = shape_of(input_x) + h_w = _op.strided_slice(input_shape, [2], [4]) + try: + in_h, in_w = infer_value(h_w, g.get_params()).numpy().tolist() + except Exception as e: + msg = "The SAME padding algorithm of Conv not support dynamic shape" + raise tvm.error.OpAttributeInvalid(msg) from e + pad_h = _get_pad_size(in_h, ksize[0], strides[0]) + pad_w = _get_pad_size(in_w, ksize[1], strides[1]) paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: @@ -587,6 +1410,11 @@ def convert_pool2d(g, op, block): msg = 'Value {} in attribute "padding" of operator Pool2d is not "valid."' raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + if not isinstance(in_h, _op.Expr) and in_h < ksize[0]: + ksize[0] = in_h + if not isinstance(in_w, _op.Expr) and in_w < ksize[1]: + ksize[1] = in_w + if not adaptive: out = getattr(_op.nn, op_map[pooling_type])( input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode @@ -596,6 +1424,197 @@ def convert_pool2d(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_max_pool2d_with_index(g, op, block): + """Operator converter for 
max_pool2d_with_index.""" + + adaptive = op.attr("adaptive") + global_pooling = op.attr("global_pooling") + ksize = op.attr("ksize") + paddings = op.attr("paddings") + if global_pooling: + adaptive = True + ksize = [1, 1] + + input_x = g.get_node(op.input("X")[0]) + + strides = op.attr("strides") + if isinstance(strides, int): + strides = [strides, strides] + if isinstance(ksize, int): + ksize = [ksize, ksize] + if isinstance(paddings, int): + paddings = [paddings] * 2 + + if not adaptive: + out = getattr(_op.nn, "max_pool2d")( + input_x, pool_size=ksize, strides=strides, padding=paddings + ) + else: + out = getattr(_op.nn, "adaptive_max_pool2d")(input_x, output_size=ksize) + g.add_node(op.output("Out")[0], out) + + +def convert_padding(g, op, block): + """Operator converter for padding.""" + + input_x = g.get_node(op.input("X")[0]) + input_padding = op.input("Paddings") + if input_padding: + padding = g.get_node(input_padding[0]) + padding = infer_value(padding, g.get_params()).numpy().tolist() + else: + padding = op.attr("paddings") + padding = op.attr("paddings") + value = op.attr("value") + data_format = op.attr("data_format") + mode = op.attr("mode") + assert mode != "circular", "Don't support mod='circular' for PaddlePaddle's padding" + if mode == "replicate": + mode = "edge" + + pad_len = len(padding) + new_paddings = [0] * (pad_len + 4) + for i in range(0, pad_len, 2): + index = -1 - i + if data_format[:2] != "NC": + index = -3 - i + new_paddings[index] = padding[i + 1] + new_paddings[index - 1] = padding[i] + + new_paddings = [new_paddings[i : i + 2] for i in range(0, len(new_paddings), 2)] + + out = _op.nn.pad(input_x, new_paddings, pad_value=value, pad_mode=mode) + g.add_node(op.output("Out")[0], out) + + +def convert_pixel_shuffle(g, op, block): + """Operator converter for pixel_shuffle.""" + + x = g.get_node(op.input("X")[0]) + upscale_factor = op.attr("upscale_factor") + out = _op.nn.depth_to_space(x, upscale_factor, mode="CRD") + g.add_node(op.output("Out")[0], out) + + +def convert_pow(g, op, block): + """Operator converter for pow.""" + + x = g.get_node(op.input("X")[0]) + factor = op.attr("factor") + factor = _expr.const(factor, dtype="float32").astype("float32") + out = _op.power(x, factor) + g.add_node(op.output("Out")[0], out) + + +def convert_prelu(g, op, block): + """Operator converter for prelu.""" + + x = g.get_node(op.input("X")[0]) + alpha = g.get_node(op.input("Alpha")[0]) + ndims = len(infer_shape(x)) + axis = 0 if ndims <= 1 else 1 + mode = op.attr("mode") + if mode == "all": + if ndims == 1: + shape = _op.strided_slice(shape_of(x), [0], [1]) + else: + shape = _op.strided_slice(shape_of(x), [1], [2]) + alpha = _op.broadcast_to(alpha, shape) + out = _op.nn.prelu(x, alpha, axis) + g.add_node(op.output("Out")[0], out) + + +def convert_norm(g, op, block): + """Operator converter for norm.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + axis = op.attr("axis") + keepdim = op.attr("keepdim") + if op.attr("asvector"): + axis = None + order = op.attr("porder") + if order == np.inf: + out = _op.reduce.max(_op.abs(x), axis=axis, keepdims=keepdim) + elif order == np.NINF: + out = _op.reduce.min(_op.abs(x), axis=axis, keepdims=keepdim) + else: + reci_order = _expr.const(1.0 / order, dtype=dtype) + order = _expr.const(order) + out = _op.power( + _op.reduce.sum(_op.power(_op.abs(x), order), axis=axis, keepdims=keepdim), + reci_order, + ) + if op.attr("asvector") and not keepdim: + out = _op.expand_dims(out, axis=0) + + 
g.add_node(op.output("Out")[0], out) + + +def convert_range(g, op, block): + """Operator converter for range.""" + + start = g.get_node(op.input("Start")[0]) + stop = g.get_node(op.input("End")[0]) + step = g.get_node(op.input("Step")[0]) + dtype = infer_type(start).checked_type.dtype + + params = [] + for param in (start, stop, step): + param = _infer_value(param, g.get_params()) + if isinstance(param, list): + param = param[0] + if isinstance(param, _expr.Expr): + param = _op.squeeze(param) + else: + param = _op.const(param, dtype=dtype) + params.append(param) + + out = _op.transform.arange(params[0], params[1], params[2], dtype=dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_reciprocal(g, op, block): + """Operator converter for reciprocal.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + out = _expr.const(1.0, dtype) / x + g.add_node(op.output("Out")[0], out) + + +def convert_reduce(g, op, block): + """Operator converter for reduce.""" + + op_map = { + "reduce_all": "all", + "reduce_any": "any", + "reduce_max": "max", + "reduce_min": "min", + "reduce_prod": "prod", + "reduce_sum": "sum", + "reduce_mean": "mean", + } + op_name = op_map[op.type] + input_x = g.get_node(op.input("X")[0]) + axis = op.attr("dim") + if op.attr("reduce_all"): + axis = None + keepdims = op.attr("keep_dim") + out = get_relay_op(op_name)(input_x, axis=axis, keepdims=keepdims) + if not axis and not keepdims: + out = _op.expand_dims(out, axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_relu6(g, op, block): + """Operator converter for relu6.""" + + x = g.get_node(op.input("X")[0]) + out = _op.clip(x, 0.0, 6.0) + g.add_node(op.output("Out")[0], out) + + def convert_reshape(g, op, block): """Operator converter for reshape.""" @@ -605,24 +1624,304 @@ def convert_reshape(g, op, block): if input_shape: new_shape = g.get_node(input_shape[0]) elif input_shape_tensor: - tmp_shape = [] + new_shape = [] for shape_name in input_shape_tensor: shape = g.get_node(shape_name) if len(infer_shape(shape)) == 0: shape = _op.reshape(shape, [-1]) - if isinstance(shape, _expr.Constant): - tmp_shape.append(shape) - elif isinstance(shape, _expr.Expr): - tmp_shape.append(shape) - else: - tmp_shape.append(_expr.const(np.array(shape).astype("int64"))) - new_shape = _op.concatenate(tmp_shape, axis=0) + new_shape.append(shape.astype("int64")) + new_shape = _op.concatenate(new_shape, axis=0) + new_shape = _infer_value(new_shape, g.get_params()) else: new_shape = op.attr("shape") out = _op.reshape(data, new_shape) g.add_node(op.output("Out")[0], out) +def convert_rnn(g, op, block): + """Operator converter for rnn.""" + + def generate_lstm( + input_seqs, + hidden_state, + cell_state, + w_inp, + w_hid, + b_inp, + b_hid, + f_act, + g_act, + h_act, + backwards=False, + ): + """Implementation of LSTM cell for paddlepaddle of TVM""" + + h_list = [] + seq_length = len(input_seqs) + for i in range(seq_length): + step = input_seqs[i] if not backwards else input_seqs[seq_length - (i + 1)] + step = _op.squeeze(step, axis=[0]) + gates = _op.nn.dense(step, w_inp) + _op.nn.dense(hidden_state, w_hid) + if b_inp is not None: + gates += b_inp + if b_hid is not None: + gates += b_hid + i, f, c, o = _op.split(gates, 4, axis=-1) + + i = f_act(i) + f = f_act(f) + + c = g_act(c) + C = f * cell_state + i * c + + o = f_act(o) + + H = o * h_act(C) + + hidden_state = H + cell_state = C + h_list.append(_op.expand_dims(H, axis=0)) + + if backwards: + h_list = h_list[::-1] + + # Concatenate outputs and add back 
in direction axis. + concatenated = _op.concatenate(h_list, 0) + output = _op.expand_dims(concatenated, axis=1) + hidden_state = _op.expand_dims(hidden_state, axis=0) + cell_state = _op.expand_dims(cell_state, axis=0) + + return output, hidden_state, cell_state + + def generate_gru( + input_seqs, hidden_state, w_inp, w_hid, b_inp, b_hid, rz_act, n_act, backwards=False + ): + """Implementation of GRU cell for paddlepaddle of TVM""" + + h_list = [] + seq_length = len(input_seqs) + for i in range(seq_length): + step = input_seqs[i] if not backwards else input_seqs[seq_length - (i + 1)] + step = _op.squeeze(step, axis=[0]) + xwt = _op.nn.dense(step, w_inp) + hwt = _op.nn.dense(hidden_state, w_hid) + if b_inp is not None: + xwt += b_inp + if b_hid is not None: + hwt += b_hid + i_r, i_z, i_n = _op.split(xwt, 3, axis=-1) + h_r, h_z, h_n = _op.split(hwt, 3, axis=-1) + + r_gate = rz_act(i_r + h_r) + z_gate = rz_act(i_z + h_z) + n_gate = n_act(i_n + r_gate * h_n) + + hidden_state = (hidden_state - n_gate) * z_gate + n_gate + h_list.append(_op.expand_dims(hidden_state, axis=0)) + + if backwards: + h_list = h_list[::-1] + + # Concatenate outputs and add back in direction axis. + concatenated = _op.concatenate(h_list, 0) + output = _op.expand_dims(concatenated, axis=1) + hidden_state = _op.expand_dims(hidden_state, axis=0) + + return output, hidden_state + + def generate_simplernn( + input_seqs, hidden_state, w_inp, w_hid, b_inp, b_hid, n_act, backwards=False + ): + """Implementation of SimpleRNN cell for paddlepaddle of TVM""" + + h_list = [] + seq_length = len(input_seqs) + for i in range(seq_length): + step = input_seqs[i] if not backwards else input_seqs[seq_length - (i + 1)] + step = _op.squeeze(step, axis=[0]) + xwt = _op.nn.dense(step, w_inp) + hwt = _op.nn.dense(hidden_state, w_hid) + if b_inp is not None: + xwt += b_inp + if b_hid is not None: + hwt += b_hid + + n_gate = n_act(xwt + hwt) + + hidden_state = n_gate + h_list.append(_op.expand_dims(hidden_state, axis=0)) + + if backwards: + h_list = h_list[::-1] + + # Concatenate outputs and add back in direction axis. 
+ concatenated = _op.concatenate(h_list, 0) + output = _op.expand_dims(concatenated, axis=1) + hidden_state = _op.expand_dims(hidden_state, axis=0) + + return output, hidden_state + + def make_param_inputs(g, node, layer, hidden_size, num_layers): + """Param for weight and bias.""" + + bidirect_len = 4 if node.attr("is_bidirec") else 2 + all_layer_param_len = len(node.input("WeightList")) + weight_list = node.input("WeightList")[: all_layer_param_len // 2] + bias_list = node.input("WeightList")[all_layer_param_len // 2 :] + + layer_weight_list = weight_list[layer * bidirect_len : layer * bidirect_len + bidirect_len] + layer_bias_list = bias_list[layer * bidirect_len : layer * bidirect_len + bidirect_len] + param_list = layer_weight_list + layer_bias_list + param_list_len = len(param_list) + + input_weights = param_list[0 : param_list_len // 2 : 2] + hidden_weights = param_list[1 : param_list_len // 2 : 2] + + input_bias = param_list[param_list_len // 2 : param_list_len : 2] + hidden_bias = param_list[param_list_len // 2 + 1 : param_list_len : 2] + + return input_weights, hidden_weights, input_bias, hidden_bias + + def make_init_param_inputs(g, node, layer): + """Init param for inputs.""" + + mode = node.attr("mode") + if mode == "LSTM": + all_init_h, all_init_c = node.input("PreState") + bidirect_len = 2 if node.attr("is_bidirec") else 1 + init_h = _op.strided_slice( + g.get_node(all_init_h), + [layer * bidirect_len], + [layer * bidirect_len + bidirect_len], + axes=[0], + ) + init_c = _op.strided_slice( + g.get_node(all_init_c), + [layer * bidirect_len], + [layer * bidirect_len + bidirect_len], + axes=[0], + ) + return init_h, init_c + all_init_h = node.input("PreState")[0] + bidirect_len = 2 if node.attr("is_bidirec") else 1 + init_h = _op.strided_slice( + g.get_node(all_init_h), + [layer * bidirect_len], + [layer * bidirect_len + bidirect_len], + axes=[0], + ) + return init_h + + hidden_size = op.attr("hidden_size") + num_layers = op.attr("num_layers") + is_bidirec = op.attr("is_bidirec") + mode = op.attr("mode") + + input_x = g.get_node(op.input("Input")[0]) + + num_directions = 1 + if is_bidirec: + num_directions = 2 + + x_shape = infer_shape(input_x) + time_steps = x_shape[0] + x_steps = _op.split(input_x, indices_or_sections=time_steps, axis=0) + for layer in range(num_layers): + input_weights, hidden_weights, input_bias, hidden_bias = make_param_inputs( + g, op, layer, hidden_size, num_layers + ) + if mode == "LSTM": + init_h, init_c = make_init_param_inputs(g, op, layer) + init_hs = _op.split(init_h, num_directions) + init_cs = _op.split(init_c, num_directions) + result_output = [] + result_H = [] + result_C = [] + for i in range(num_directions): + H_t = _op.squeeze(init_hs[i], axis=[0]) + C_t = _op.squeeze(init_cs[i], axis=[0]) + W = g.get_node(input_weights[i]) + R = g.get_node(hidden_weights[i]) + WB = g.get_node(input_bias[i]) + RB = g.get_node(hidden_bias[i]) + output, H, C = generate_lstm( + input_seqs=x_steps, + hidden_state=H_t, + cell_state=C_t, + w_inp=W, + w_hid=R, + b_inp=WB, + b_hid=RB, + f_act=_op.sigmoid, + g_act=_op.tanh, + h_act=_op.tanh, + backwards=i == 1, + ) + result_output.append(output) + result_H.append(H) + result_C.append(C) + output = _op.concatenate(result_output, axis=1) + H = _op.concatenate(result_H, axis=0) + C = _op.concatenate(result_C, axis=0) + elif mode == "GRU": + init_h = make_init_param_inputs(g, op, layer) + init_hs = _op.split(init_h, num_directions) + result_output = [] + result_H = [] + for i in range(num_directions): + H_t = 
_op.squeeze(init_hs[i], axis=[0]) + W = g.get_node(input_weights[i]) + R = g.get_node(hidden_weights[i]) + WB = g.get_node(input_bias[i]) + RB = g.get_node(hidden_bias[i]) + output, H = generate_gru( + input_seqs=x_steps, + hidden_state=H_t, + w_inp=W, + w_hid=R, + b_inp=WB, + b_hid=RB, + rz_act=_op.sigmoid, + n_act=_op.tanh, + backwards=i == 1, + ) + result_output.append(output) + result_H.append(H) + output = _op.concatenate(result_output, axis=1) + H = _op.concatenate(result_H, axis=0) + elif mode == "RNN_TANH": + init_h = make_init_param_inputs(g, op, layer) + init_hs = _op.split(init_h, num_directions) + result_output = [] + result_H = [] + for i in range(num_directions): + H_t = _op.squeeze(init_hs[i], axis=[0]) + W = g.get_node(input_weights[i]) + R = g.get_node(hidden_weights[i]) + WB = g.get_node(input_bias[i]) + RB = g.get_node(hidden_bias[i]) + output, H = generate_simplernn( + input_seqs=x_steps, + hidden_state=H_t, + w_inp=W, + w_hid=R, + b_inp=WB, + b_hid=RB, + n_act=_op.tanh, + backwards=i == 1, + ) + result_output.append(output) + result_H.append(H) + output = _op.concatenate(result_output, axis=1) + H = _op.concatenate(result_H, axis=0) + + output = _op.transpose(output, axes=[0, 2, 1, 3]) + output = _op.reshape(output, newshape=(0, 0, -1)) + x_steps = _op.split(output, indices_or_sections=time_steps, axis=0) + + g.add_node(op.output("Out")[0], output) + + def convert_scale(g, op, block): """Operator converter for scale.""" @@ -631,8 +1930,11 @@ def convert_scale(g, op, block): bias_after_scale = op.attr("bias_after_scale") x = g.get_node(op.input("X")[0]) if np.isclose(scale, 1.0) and np.isclose(bias, 0.0): - out = _op.copy(x) + out = x else: + x_dtype = infer_type(x).checked_type.dtype + if x_dtype != "float32": + x = x.astype("float32") if np.isclose(bias, 0.0): out = x * _expr.const(np.array(scale).astype("float32")) elif np.isclose(scale, 1.0): @@ -646,6 +1948,58 @@ def convert_scale(g, op, block): out = (x + _expr.const(np.array(bias).astype("float32"))) * _expr.const( np.array(scale).astype("float32") ) + if x_dtype != "float32": + out = out.astype(x_dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_scatter(g, op, block): + """Operator converter for scatter.""" + + x = g.get_node(op.input("X")[0]) + index = g.get_node(op.input("Ids")[0]) + updates = g.get_node(op.input("Updates")[0]) + overwrite = op.attr("overwrite") + + shape = infer_shape(updates) + ndims = len(shape) + index = _op.expand_dims(index, axis=-1, num_newaxis=ndims - 1) + index = _op.transform.broadcast_to(index, shape) + + if overwrite: + out = _op.scatter(x, index, updates, axis=0) + else: + out = _op.scatter_add(_op.zeros_like(x), index, updates, axis=0) + out += _op.scatter(x, index, _op.zeros_like(updates), axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_scatter_nd_add(g, op, block): + """Operator converter for scatter_nd_add.""" + + x = g.get_node(op.input("X")[0]) + index = g.get_node(op.input("Index")[0]) + updates = g.get_node(op.input("Updates")[0]) + indices_dim = len(infer_shape(index)) + axes = list(range(indices_dim)) + index = _op.transpose(index, axes[-1:] + axes[:-1]) + out = _op.scatter_nd(x, index, updates, mode="add") + g.add_node(op.output("Out")[0], out) + + +def convert_selu(g, op, block): + """Operator converter for selu.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + alpha = _op.const(op.attr("alpha"), dtype) + scale = _op.const(op.attr("scale"), dtype) + out = ( + _expr.const(-1.0, dtype=dtype) + * alpha + * 
_op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(x)) + ) + out = scale * (out + _op.nn.relu(x)) g.add_node(op.output("Out")[0], out) @@ -660,38 +2014,107 @@ def convert_shape(g, op, block): def convert_slice(g, op, block): """Operator converter for slice.""" - def parameter_process(starts, ends, axes, dshape): - new_axes = [] - new_starts = [] - new_ends = [] - pop_index = 0 - for i in range(max(axes) + 1): - new_axes.append(i) - if i in axes: - new_starts.append(starts[pop_index]) - new_ends.append(ends[pop_index]) - pop_index += 1 - else: - new_starts.append(0) - new_ends.append(dshape[i]) - return new_starts, new_ends, new_axes - data = g.get_node(op.input("Input")[0]) - dshape = infer_shape(data) - starts = op.attr("starts") - ends = op.attr("ends") + dims = len(infer_shape(data)) + axes = op.attr("axes") + indices = _expr.const(axes, dtype="int64") + decrease_axis = op.attr("decrease_axis") - if isinstance(starts, int): - starts = [starts] - if isinstance(ends, int): - ends = [ends] - if isinstance(axes, int): - axes = [axes] if isinstance(decrease_axis, int): decrease_axis = [decrease_axis] - starts, ends, axes = parameter_process(starts, ends, axes, dshape) - out = _op.strided_slice(data, begin=starts, end=ends) + + if op.input("StartsTensor"): + starts = g.get_node(op.input("StartsTensor")[0]) + starts = _infer_value(starts, g.get_params()) + elif op.input("StartsTensorList"): + starts = [] + for start_index in op.input("StartsTensorList"): + start_index = g.get_node(start_index).astype("int64") + starts.append(start_index) + starts = _op.concatenate(starts, axis=0) + starts = _infer_value(starts, g.get_params()) + else: + starts = op.attr("starts") + + if len(axes) < dims: + if isinstance(starts, _expr.Expr): + starts = _op.scatter( + _op.const([0] * dims, dtype=infer_type(starts).checked_type.dtype), + indices, + starts, + axis=0, + ) + else: + base = [0] * dims + for i, axis in enumerate(axes): + base[axis] = starts[i] + starts = base + + if op.input("EndsTensor"): + ends = g.get_node(op.input("EndsTensor")[0]) + ends = _infer_value(ends, g.get_params()) + elif op.input("EndsTensorList"): + ends = [] + for end_index in op.input("EndsTensorList"): + end_index = g.get_node(end_index).astype("int64") + ends.append(end_index) + ends = _op.concatenate(ends, axis=0) + ends = _infer_value(ends, g.get_params()) + else: + ends = op.attr("ends") + + if len(axes) < dims: + if isinstance(ends, _expr.Expr): + ends = _op.scatter( + _expr.const( + np.array([np.iinfo(np.int32).max] * dims), + dtype=infer_type(ends).checked_type.dtype, + ), + indices, + ends, + axis=0, + ) + else: + base = [np.iinfo(np.int32).max] * dims + for i, axis in enumerate(axes): + base[axis] = ends[i] + ends = base + + strides = None + if "StridesTensor" in op.input_names and op.input("StridesTensor"): + strides = g.get_node(op.input("StridesTensor")[0]) + strides = _infer_value(strides, g.get_params()) + elif "StridesTensorList" in op.input_names and op.input("StridesTensorList"): + strides = [] + for strides_index in op.input("StridesTensorList"): + strides_index = g.get_node(strides_index).astype("int64") + strides.append(strides_index) + strides = _op.concatenate(strides, axis=0) + strides = _infer_value(strides, g.get_params()) + elif op.has_attr("strides"): + strides = op.attr("strides") + + if len(axes) < dims: + if isinstance(strides, _expr.Expr): + strides = _op.scatter( + _expr.const( + np.array([1] * dims), + dtype=infer_type(strides).checked_type.dtype, + ), + indices, + strides, + axis=0, + ) + elif 
strides: + base = [1] * dims + for i, axis in enumerate(axes): + base[axis] = strides[i] + strides = base + if not strides: + strides = _op.const([1] * dims, dtype="int64") + + out = _op.strided_slice(data, begin=starts, end=ends, strides=strides) if decrease_axis: out = _op.squeeze(out, axis=decrease_axis) g.add_node(op.output("Out")[0], out) @@ -711,6 +2134,184 @@ def convert_softmax(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_softplus(g, op, block): + """Operator converter for softplus.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + beta = op.attr("beta") + beta = _expr.const(beta, dtype=dtype) + out = _op.log(_op.exp(x * beta) + _expr.const(1.0, dtype=dtype)) / beta + g.add_node(op.output("Out")[0], out) + + +def convert_softshrink(g, op, block): + """Operator converter for softshrink.""" + + x = g.get_node(op.input("X")[0]) + threshold = op.attr("lambda") + out = x - _op.clip(x, -1.0 * threshold, threshold) + g.add_node(op.output("Out")[0], out) + + +def convert_softsign(g, op, block): + """Operator converter for softsign.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + out = x / (_op.const(1.0, dtype) + _op.abs(x)) + g.add_node(op.output("Out")[0], out) + + +def convert_swish(g, op, block): + """Operator converter for swish.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + out = x / (_op.const(1.0, dtype) + _op.exp(_op.const(-1.0, dtype) * x)) + g.add_node(op.output("Out")[0], out) + + +def convert_tanhshrink(g, op, block): + """Operator converter for swish.""" + + x = g.get_node(op.input("X")[0]) + out = x - _op.tanh(x) + g.add_node(op.output("Out")[0], out) + + +def convert_thresholded_relu(g, op, block): + """Operator converter for thresholded_relu.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + threshold = _op.const(op.attr("threshold"), dtype) + out = _op.where(_op.greater(x, threshold), x, _op.const(0.0, dtype)) + g.add_node(op.output("Out")[0], out) + + +def convert_split(g, op, block): + """Operator converter for split.""" + + x = g.get_node(op.input("X")[0]) + axis = op.input("AxisTensor") + if axis: + axis = g.get_node(axis[0]) + axis = infer_value(axis, g.get_params()).numpy().tolist()[0] + else: + axis = op.attr("axis") + + sections = op.input("SectionsTensorList") + if sections: + tmp_section = [] + for i in sections: + i = g.get_node(i) + i = infer_value(i, g.get_params()).numpy().tolist() + tmp_section.extend(i) + sections = tmp_section + else: + sections = op.attr("sections") + if sections: + indices = [] + split_index = 0 + for i in sections[:-1]: + if i == -1: + input_shape = infer_shape(x)[axis] + i = input_shape - np.sum(sections) - 1 + split_index += i + indices.append(split_index) + else: + indices = op.attr("num") + + out = _op.split(x, indices, axis) + for i, out_i in enumerate(out): + g.add_node(op.output("Out")[i], out_i) + + +def convert_square(g, op, block): + """Operator converter for square.""" + + x = g.get_node(op.input("X")[0]) + out = _op.multiply(x, x) + g.add_node(op.output("Out")[0], out) + + +def convert_squeeze(g, op, block): + """Operator converter for squeeze2.""" + + x = g.get_node(op.input("X")[0]) + axes = op.attr("axes") + if not axes: + axes = None + x = _op.squeeze(x, axis=axes) + g.add_node(op.output("Out")[0], x) + + +def convert_topk(g, op, block): + """Operator converter for topk.""" + + x = g.get_node(op.input("X")[0]) + axis = op.attr("axis") + largest = 
op.attr("largest") + is_ascend = not bool(largest) + k_node = op.input("K") + if k_node: + k_node = g.get_node(k_node[0]) + k = _infer_value(k_node, g.get_params()) + else: + k = op.attr("k") + outs = _op.topk(x, k=k, axis=axis, is_ascend=is_ascend, ret_type="both", dtype="int64") + + g.add_node(op.output("Out")[0], outs[0]) + g.add_node(op.output("Indices")[0], outs[1]) + + +def convert_stack(g, op, block): + """Operator converter for stack.""" + + inputs = op.input("X") + inputs = [g.get_node(i) for i in inputs] + axis = op.attr("axis") + inputs = _dtype_shape_promotion(inputs) + out = _op.stack(inputs, axis) + g.add_node(op.output("Y")[0], out) + + +def convert_tile(g, op, block): + """Operator converter for tile.""" + + input_x = g.get_node(op.input("X")[0]) + repeat_times = op.input("RepeatTimes") + repeat_times_tensor = op.input("repeat_times_tensor") + if repeat_times: + repeat_times = g.get_node(repeat_times[0]) + elif repeat_times_tensor: + tmp_shape = [] + for shape_name in repeat_times_tensor: + shape = g.get_node(shape_name) + if len(infer_shape(shape)) == 0: + shape = _op.reshape(shape, [-1]) + if isinstance(shape, _expr.Constant): + tmp_shape.append(shape) + elif isinstance(shape, _expr.Expr): + tmp_shape.append(shape) + else: + tmp_shape.append(_expr.const(np.array(shape).astype("int32"))) + repeat_times = _op.concatenate(tmp_shape, axis=0) + else: + repeat_times = op.attr("repeat_times") + out = _op.tile(input_x, repeat_times) + g.add_node(op.output("Out")[0], out) + + +def convert_transpose(g, op, block): + """Operator converter for transpose.""" + + perm = op.attr("axis") + out = _op.transpose(g.get_node(op.input("X")[0]), axes=perm) + g.add_node(op.output("Out")[0], out) + + def convert_unsqueeze(g, op, block): """Operator converter for unsqueeze.""" @@ -721,64 +2322,240 @@ def convert_unsqueeze(g, op, block): g.add_node(op.output("Out")[0], x) +def convert_unstack(g, op, block): + """Operator converter for unstack.""" + + x = g.get_node(op.input("X")[0]) + axis = op.attr("axis") + num = op.attr("num") + out = _op.split(x, num, axis=axis) + for i, out_i in enumerate(out): + out_i = _op.squeeze(out_i, axis=[axis]) + g.add_node(op.output("Y")[i], out_i) + + +def convert_unique(g, op, block): + """Operator converter for unique.""" + + x = g.get_node(op.input("X")[0]) + ndim = len(infer_shape(x)) + assert ndim == 1, "Only support 1D Tensor for PaddlePaddle's unique" + is_sorted = op.attr("is_sorted") + return_counts = op.attr("return_counts") + return_index = op.attr("return_index") + return_inverse = op.attr("return_inverse") + if return_counts: + [unique, indices, inverse_indices, num_uniq, counts] = _op.unique( + x, is_sorted=is_sorted, return_counts=True + ) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") + indices_sliced = _op.strided_slice(indices, begin=[0], end=num_uniq, slice_mode="size") + counts_sliced = _op.cast(counts_sliced, "int64") + g.add_node(op.output("Counts")[0], counts_sliced) + else: + [unique, indices, inverse_indices, num_uniq] = _op.unique( + x, is_sorted=is_sorted, return_counts=False + ) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + indices_sliced = _op.strided_slice(indices, begin=[0], end=num_uniq, slice_mode="size") + + inverse_indices = _op.cast(inverse_indices, "int64") + indices_sliced = _op.cast(indices_sliced, "int64") + g.add_node(op.output("Out")[0], unique_sliced) + 
if return_index: + g.add_node(op.output("Indices")[0], indices_sliced) + if return_inverse: + g.add_node(op.output("Index")[0], inverse_indices) + + +def convert_where(g, op, block): + """Operator converter for where.""" + + condition = g.get_node(op.input("Condition")[0]) + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + out = _op.where(condition, x, y) + g.add_node(op.output("Out")[0], out) + + _convert_map = { + "abs": convert_unary_op, + "acos": convert_unary_op, + "addmm": convert_addmm, "arg_max": convert_arg_max, + "arg_min": convert_arg_min, + "argsort": convert_argsort, + "asin": convert_unary_op, "assign": convert_assign, + "assign_value": convert_assign_value, + "atan": convert_unary_op, "batch_norm": convert_batch_norm, + "bicubic_interp_v2": convert_interpolate, + "bilinear_interp_v2": convert_interpolate, + "bmm": convert_bmm, + "brelu": convert_hard_tanh, "cast": convert_cast, + "ceil": convert_unary_op, + "clip": convert_clip, "concat": convert_concat, "conv2d": convert_conv2d, + "conv2d_transpose": convert_conv2d_transpose, + "cos": convert_unary_op, + "cosh": convert_unary_op, + "crop_tensor": convert_crop, "cumsum": convert_cumsum, "depthwise_conv2d": convert_conv2d, + "dist": convert_dist, + "dot": convert_dot, "dropout": convert_dropout, "elementwise_add": convert_elementwise_op, "elementwise_div": convert_elementwise_op, "elementwise_mul": convert_elementwise_op, "elementwise_sub": convert_elementwise_op, - "equal": convert_equal, - "exp": convert_activation, + "elementwise_mod": convert_elementwise_op, + "elementwise_max": convert_elementwise_op, + "elementwise_min": convert_elementwise_op, + "elementwise_pow": convert_elementwise_op, + "elementwise_floordiv": convert_elementwise_op, + "elu": convert_elu, + "equal": convert_elementwise_op, + "erf": convert_unary_op, + "exp": convert_unary_op, + "expand_v2": convert_expand, + "expand_as_v2": convert_expand_as, "feed": convert_feed, "fill_any_like": convert_fill_any_like, "fill_constant": convert_fill_constant, + "fill_constant_batch_size_like": convert_fill_constant_batch_size_like, + "flatten_contiguous_range": convert_flatten, + "floor": convert_unary_op, + "floor_mod": convert_elementwise_op, + "gather": convert_gather, + "gather_nd": convert_gather_nd, "gelu": convert_gelu, + "greater_equal": convert_elementwise_op, + "greater_than": convert_elementwise_op, + "group_norm": convert_group_norm, + "hard_shrink": convert_hard_shrink, "hard_sigmoid": convert_hard_sigmoid, "hard_swish": convert_hard_swish, + "index_select": convert_index_select, + "isfinite": convert_unary_op, + "isfinite_v2": convert_unary_op, + "instance_norm": convert_instance_norm, + "isinf": convert_unary_op, + "isinf_v2": convert_unary_op, + "isnan": convert_unary_op, + "isnan_v2": convert_unary_op, "layer_norm": convert_layer_norm, "leaky_relu": convert_leaky_relu, + "less_equal": convert_elementwise_op, + "less_than": convert_elementwise_op, + "lookup_table": convert_lookup_table, "lookup_table_v2": convert_lookup_table, + "log": convert_unary_op, + "log2": convert_unary_op, + "log10": convert_unary_op, + "log1p": convert_log1p, + "logical_and": convert_logical_op, + "logical_not": convert_logical_not, + "logical_or": convert_logical_op, + "logical_xor": convert_logical_op, + "logsigmoid": convert_logsigmoid, + "log_softmax": convert_logsoftmax, + "logsumexp": convert_logsumexp, + "masked_select": convert_masked_select, "matmul": convert_matmul, "matmul_v2": convert_matmul, + "meshgrid": convert_meshgrid, + "mv": 
convert_mv, "mul": convert_mul, + "nearest_interp_v2": convert_interpolate, + "not_equal": convert_elementwise_op, "pool2d": convert_pool2d, - "relu": convert_activation, + "max_pool2d_with_index": convert_max_pool2d_with_index, + "pad1d": convert_padding, + "pad2d": convert_padding, + "pad3d": convert_padding, + "pixel_shuffle": convert_pixel_shuffle, + "pow": convert_pow, + "prelu": convert_prelu, + "p_norm": convert_norm, + "range": convert_range, + "reciprocal": convert_reciprocal, + "reduce_all": convert_reduce, + "reduce_any": convert_reduce, + "reduce_max": convert_reduce, + "reduce_min": convert_reduce, + "reduce_prod": convert_reduce, + "reduce_sum": convert_reduce, + "reduce_mean": convert_reduce, + "relu": convert_unary_op, + "relu6": convert_relu6, "reshape2": convert_reshape, + "rnn": convert_rnn, + "round": convert_unary_op, + "rsqrt": convert_unary_op, "scale": convert_scale, + "scatter": convert_scatter, + "scatter_nd_add": convert_scatter_nd_add, + "selu": convert_selu, "shape": convert_shape, + "sigmoid": convert_unary_op, + "sign": convert_unary_op, + "sin": convert_unary_op, + "sinh": convert_unary_op, + "size": convert_numel, "slice": convert_slice, "softmax": convert_softmax, - "tanh": convert_activation, + "softplus": convert_softplus, + "softshrink": convert_softshrink, + "softsign": convert_softsign, + "split": convert_split, + "sqrt": convert_unary_op, + "square": convert_square, + "squeeze2": convert_squeeze, + "stack": convert_stack, + "strided_slice": convert_slice, + "sum": convert_addn, + "swish": convert_swish, + "tan": convert_unary_op, + "tanh": convert_unary_op, + "tanh_shrink": convert_tanhshrink, + "thresholded_relu": convert_thresholded_relu, + "top_k_v2": convert_topk, + "tile": convert_tile, + "transpose2": convert_transpose, "unsqueeze2": convert_unsqueeze, + "unstack": convert_unstack, + "unique": convert_unique, + "where": convert_where, + "where_index": convert_nonzero, } class GraphProto: """A helper class for handling relay functions from PaddlePaddle model.""" - def __init__(self): + def __init__(self, freeze_params=False): self.nodes = {} self.params = {} self.shape_dict = None + self.freeze_params = freeze_params def get_node(self, name): """get node from graph""" - assert name in self.nodes + assert name in self.nodes, "Node: {} not found".format(name) return self.nodes[name] def add_node(self, name, node): """add a node to graph""" - - self.nodes[name] = fold_constant(node) + if self.shape_dict: + self.nodes[name] = fold_constant(node) + else: + self.nodes[name] = node def get_params(self, name=None): """get params from graph""" @@ -788,6 +2565,11 @@ def get_params(self, name=None): assert name in self.params return self.params[name] + def set_params(self, params): + """set params for graph""" + + self.params = params + def extract_parameters(self, program, scope=None): """Extract all the weights from PaddlePaddle program.""" @@ -803,7 +2585,12 @@ def extract_parameters(self, program, scope=None): self.params[name] = scope[name] else: self.params[name] = np.array(scope.var(name).get_tensor()) - self.nodes[name] = _expr.const(self.params[name]) + if self.freeze_params: + self.nodes[name] = _expr.const(self.params[name]) + else: + self.nodes[name] = _expr.var( + name, shape=self.params[name].shape, dtype=str(self.params[name].dtype) + ) def check_input_shape(self, op, block): """Check the shape information of model's inputs, fixed shape is recommended.""" @@ -826,6 +2613,8 @@ def check_unsupported_ops(self, program): for op in block.ops: if 
op.type == "fetch": continue + if op.type in ControlFlow.operators: + continue if op.type not in _convert_map: unsupported_ops.add(op.type) if len(unsupported_ops) > 0: @@ -839,12 +2628,15 @@ def ops_to_relay(self, program, input_specs=None): if input_specs is not None: for input_spec in input_specs: convert_feed(self, input_spec, None) - for block in program.blocks: - for op in block.ops: - if op.type == "fetch": - continue + global_block = program.blocks[0] + for op in global_block.ops: + if op.type == "fetch": + continue + if op.type in ControlFlow.operators: + ControlFlow.convert(self, op, program) + else: convert_func = _convert_map[op.type] - convert_func(self, op, block) + convert_func(self, op, global_block) def from_program(self, program, shape_dict, scope): """Construct the TVM relay expression from PaddlePaddle program.""" @@ -864,12 +2656,14 @@ def from_program(self, program, shape_dict, scope): if op.type == "fetch": output_names.append(op.input("X")[0]) - outputs = [self.nodes[name] for name in output_names] + outputs = [self.get_node(name) for name in output_names] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) free_vars = analysis.free_vars(outputs) func = _function.Function(free_vars, outputs) mod = IRModule.from_expr(func) + if self.freeze_params: + self.params = {} return mod, self.params def from_translated_layer(self, layer, shape_dict): @@ -888,25 +2682,27 @@ def from_translated_layer(self, layer, shape_dict): output_names = [x.name for x in layer._output_spec()] - outputs = [self.nodes[name] for name in output_names] + outputs = [self.get_node(name) for name in output_names] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) free_vars = analysis.free_vars(outputs) func = _function.Function(free_vars, outputs) mod = IRModule.from_expr(func) + if self.freeze_params: + self.params = {} return mod, self.params -def from_paddle(program_or_layer, shape_dict=None, scope=None): +def from_paddle(program_or_layer, shape_dict=None, scope=None, freeze_params=False): """Convert a PaddlePaddle model into an equivalent Relay Function. - - PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model, - and PaddlePaddle scope stores all the weights of PaddlePaddle model. + PaddlePaddle Program/TranslatedLayer represent the computation + graph of PaddlePaddle model, and PaddlePaddle scope stores all the + weights of PaddlePaddle model. """ import paddle - g = GraphProto() + g = GraphProto(freeze_params) if isinstance(program_or_layer, paddle.jit.TranslatedLayer): # model is loaded by `paddle.jit.load` mod, params = g.from_translated_layer(program_or_layer, shape_dict) diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index db07e07f9d83..00fe97f3146d 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -14,19 +14,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
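A minimal usage sketch of the updated entry point, assuming a model previously saved with paddle.jit.save; the path, input name, and shape below are hypothetical and only illustrate the new freeze_params option:

    import paddle
    from tvm import relay

    layer = paddle.jit.load("./inference/model")  # hypothetical path
    shape_dict = {"inputs": [1, 3, 224, 224]}  # hypothetical input name and shape
    # freeze_params=True folds the weights into the Relay module as constants and
    # returns an empty params dict; with the default False the weights are returned
    # separately and bound at build time (as the tests below do via relay.build).
    mod, params = relay.frontend.from_paddle(layer, shape_dict, freeze_params=True)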
-import os from pathlib import Path import shutil import numpy as np + +import paddle +from paddle.framework import dtype +import paddle.nn as nn + import tvm import tvm.testing import tvm.topi.testing from tvm import relay from tvm.contrib import graph_executor -import paddle -import paddle.nn as nn PADDLE_TEST_DATA_ROOT_PATH = Path(Path("~").expanduser(), ".tvm_test_data", "paddle") PADDLE_TEST_DATA_ROOT_PATH.mkdir(parents=True, exist_ok=True) @@ -34,14 +36,16 @@ def assert_shapes_match(tru, est): if tru.shape != est.shape: - msg = "Output shapes {} and {} don't match" + msg = "Paddle Output shapes {} and TVM shapes {} don't match" raise AssertionError(msg.format(tru.shape, est.shape)) + if tru.dtype != est.dtype: + msg = "Paddle Output dtype {} and TVM dtype {} don't match" + raise AssertionError(msg.format(tru.dtype, est.dtype)) def get_paddle_model(func, input_spec): global PADDLE_TEST_DATA_ROOT_PATH model_path = Path(PADDLE_TEST_DATA_ROOT_PATH, "model") - paddle.jit.save(func, str(model_path), input_spec=input_spec) baseline_model = paddle.jit.load(str(model_path)) @@ -49,7 +53,34 @@ def get_paddle_model(func, input_spec): return baseline_model -def verify_model(func, input_data, rtol=1e-5, atol=1e-5): +def get_tvm_output_with_vm(mod, params, target, device, input_data): + """Generic function to execute and get tvm output with vm executor""" + + ex = relay.create_executor("vm", mod=mod, device=device, target=target) + params.update(input_data) + result = ex.evaluate()(**params) + if isinstance(result, tvm.runtime.NDArray): + return [ + result.numpy(), + ] + return [r.numpy() for r in result] + + +def get_tvm_output(mod, params, target, device, input_data, compiled_names, num): + """Generic function to execute and get tvm output""" + + lib = relay.build(mod, target=target, params=params) + gmod = graph_executor.GraphModule(lib["default"](device)) + for name in compiled_names: + gmod.set_input(name, input_data[name]) + gmod.run() + outputs = [] + for i in range(num): + outputs.append(gmod.get_output(i).numpy()) + return outputs + + +def verify_model(func, input_data, rtol=1e-5, atol=1e-5, input_shape=None): if not (isinstance(input_data, (tuple, list))): input_data = [input_data] @@ -59,11 +90,13 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5): compiled_input = {} for idx, data in enumerate(input_data): input_name = "input{}".format(idx) - input_spec.append( - paddle.static.InputSpec(dtype=data.dtype, shape=data.shape, name=input_name) - ) + if input_shape: + shape = input_shape[idx] + else: + shape = data.shape + input_shape_dict[input_name] = shape + input_spec.append(paddle.static.InputSpec(dtype=data.dtype, shape=shape, name=input_name)) input_names.append(input_name) - input_shape_dict[input_name] = data.shape if isinstance(data, np.ndarray): compiled_input[input_name] = data else: @@ -81,25 +114,74 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5): mod, params = relay.frontend.from_paddle(baseline_model, input_shape_dict) parms_num = min(len(input_names), len(mod["main"].params)) compiled_names = [] - for arg in mod["main"].params[:parms_num]: - assert arg.name_hint in input_names - compiled_names.append(arg.name_hint) + for arg in mod["main"].params: + assert arg.name_hint in input_names or arg.name_hint in params + if arg.name_hint in input_names: + compiled_names.append(arg.name_hint) with tvm.transform.PassContext(opt_level=3): for target, dev in tvm.testing.enabled_targets(): - lib = relay.build(mod, target=target, params=params) - gmod = 
graph_executor.GraphModule(lib["default"](dev)) - for name in compiled_names: - gmod.set_input(name, compiled_input[name]) - gmod.run() - - for i, baseline_output in enumerate(baseline_outputs): - compiled_output = gmod.get_output(i).numpy() - + if input_shape: + tvm_output = get_tvm_output_with_vm(mod, params, target, dev, compiled_input) + else: + tvm_output = get_tvm_output( + mod, params, target, dev, compiled_input, compiled_names, len(baseline_outputs) + ) + + for baseline_output, compiled_output in zip(baseline_outputs, tvm_output): assert_shapes_match(baseline_output, compiled_output) tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol) +@tvm.testing.uses_gpu +def test_forward_math(): + class MathOp(nn.Layer): + def __init__(self, op_name): + super(MathOp, self).__init__() + for candidate in (paddle, paddle.nn.functional): + self.func = getattr(candidate, op_name, None) + if self.func: + break + + @paddle.jit.to_static + def forward(self, inputs): + return self.func(inputs) + + input_data = paddle.rand([1, 2, 5, 5], dtype="float32") + op_list = [ + "abs", + "acos", + "asin", + "atan", + "ceil", + "cos", + "cosh", + "erf", + "exp", + "floor", + "log", + "log2", + "log10", + "log1p", + "numel", + "reciprocal", + "relu", + "round", + "rsqrt", + "sigmoid", + "sign", + "rsqrt", + "sin", + "sinh", + "square", + "sqrt", + "tan", + "tanh", + ] + for op_name in op_list: + verify_model(MathOp(op_name), input_data) + + @tvm.testing.uses_gpu def test_forward_add_subtract(): input_shape = [10] @@ -124,6 +206,76 @@ def add_subtract3(inputs1, inputs2): verify_model(add_subtract3, [input_data, input_data2]) +@tvm.testing.uses_gpu +def test_forward_addmm(): + @paddle.jit.to_static + def addmm(input, x, y, alpha=1, beta=1): + return paddle.addmm(input, x, y, alpha, beta) + + input_shape = [10, 10] + x_shape = [10, 3] + y_shape = [3, 10] + input_data = paddle.rand(input_shape, dtype="float32") + x_data = paddle.rand(x_shape, dtype="float32") + y_data = paddle.rand(y_shape, dtype="float32") + verify_model(addmm, input_data=[input_data, x_data, y_data]) + + +@tvm.testing.uses_gpu +def test_forward_addn(): + @paddle.jit.to_static + def addn(a, b, c): + return paddle.add_n([a, b, c]) + + @paddle.jit.to_static + def addn2(a, b): + return paddle.add_n([a, b]) + + @paddle.jit.to_static + def addn3(a): + return paddle.add_n([a]) + + input_shape = [1, 3, 10, 10] + a = paddle.rand(input_shape, dtype="float32") + b = paddle.rand(input_shape, dtype="float32") + c = paddle.rand(input_shape, dtype="float32") + verify_model(addn, [a, b, c]) + verify_model( + addn2, + [ + a, + b, + ], + ) + verify_model( + addn3, + [ + a, + ], + ) + + +@tvm.testing.uses_gpu +def test_forward_arange(): + @paddle.jit.to_static + def arange(inputs): + return paddle.arange(paddle.shape(inputs)[0], 9, 2.0) + + @paddle.jit.to_static + def arange2(inputs): + return paddle.arange(0, 10.0, paddle.shape(inputs)[1], dtype="float32") + + @paddle.jit.to_static + def arange3(inputs): + return paddle.arange(0, inputs, 2.0, dtype="float32") + + input_shape = [2, 2] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(arange, input_data) + verify_model(arange2, input_data=input_data) + verify_model(arange3, paddle.to_tensor(np.array([10], dtype="int32")), input_shape=[[1]]) + + @tvm.testing.uses_gpu def test_forward_argmax(): input_shape = [1, 3, 10, 10] @@ -156,26 +308,65 @@ def forward(self, inputs): @tvm.testing.uses_gpu -def test_forward_assign(): +def test_forward_argmin(): + input_shape = [1, 
3, 10, 10] + + class ArgMin(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.argmin(inputs) + + class ArgMin1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmin(axis=1) + + class ArgMin2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmin(axis=1, keepdim=False) + + class ArgMin3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmin(axis=2, keepdim=True) + + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ArgMin(), input_data=input_data) + verify_model(ArgMin1(), input_data=input_data) + verify_model(ArgMin2(), input_data=input_data) + verify_model(ArgMin3(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_argsort(): + @paddle.jit.to_static + def argsort(inputs): + return paddle.argsort(inputs) + @paddle.jit.to_static - def assign(inputs): - return paddle.assign(inputs) + def argsort2(inputs): + return paddle.argsort(inputs, axis=0, descending=True) + + input_shape = [2, 3, 5] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(argsort, input_data) + input_data2 = np.random.randint(100, size=input_shape) + verify_model(argsort2, input_data2) + + +@tvm.testing.uses_gpu +def test_forward_assign(): + class Assign(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.assign(inputs) input_shape = [2, 3] input_data = paddle.rand(input_shape, dtype="float32") - verify_model( - assign, - [ - input_data, - ], - ) + verify_model(Assign(), [input_data]) input_data2 = np.random.randint(100, size=input_shape) - verify_model( - assign, - [ - input_data2, - ], - ) + verify_model(Assign(), [input_data2], input_shape=[[-1, -1]]) @tvm.testing.uses_gpu @@ -227,18 +418,36 @@ def cast2(inputs, dtype="int64"): input_shape = [2, 3] input_data = paddle.rand(input_shape, dtype="float32") * 100 - verify_model( - cast1, - [ - input_data, - ], - ) - verify_model( - cast2, - [ - input_data, - ], - ) + verify_model(cast1, [input_data]) + verify_model(cast2, [input_data]) + + +@tvm.testing.uses_gpu +def test_forward_clip(): + @paddle.jit.to_static + def clip(inputs): + return paddle.clip(inputs, min=3, max=5) + + @paddle.jit.to_static + def clip2(inputs, max_value): + return paddle.clip(inputs, max=max_value) + + @paddle.jit.to_static + def clip3(inputs, min_value): + return paddle.clip(inputs, min=min_value) + + @paddle.jit.to_static + def clip4(inputs, min_value, max_value): + return paddle.clip(inputs, min=min_value, max=max_value) + + verify_model(clip, paddle.to_tensor([[1, 2], [4, 6]], dtype="int32")) + x = np.array([[1.2, 3.5], [4.5, 6.4]]) + x1 = paddle.to_tensor(x, dtype="float32") + min_value = paddle.to_tensor(np.array([2.1]), dtype="float32") + max_value = paddle.to_tensor(np.array([4.5]), dtype="float32") + verify_model(clip2, [x1, max_value]) + verify_model(clip3, [x1, min_value]) + verify_model(clip4, [x1, min_value, max_value]) @tvm.testing.uses_gpu @@ -261,40 +470,60 @@ def concat_unsqueeze2(inputs): @tvm.testing.uses_gpu -def test_forward_cumsum(): +def test_forward_crop(): + @paddle.jit.to_static + def crop1(inputs): + return paddle.crop(inputs, shape=[2, 2]) + @paddle.jit.to_static - def cusum1(inputs): - return paddle.cumsum(inputs) + def crop2(inputs, shape): + return paddle.crop(inputs, shape=shape, offsets=[0, 1]) @paddle.jit.to_static - def cusum2(inputs): - return paddle.cumsum(inputs, axis=0) + def crop3(inputs): + offsets = paddle.to_tensor(np.array([1, 0]).astype("int32")) 
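+ # offsets is built as a tensor inside the traced function, exercising tensor-valued offsets (crop2 above passes a plain Python list)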
+ return paddle.crop(inputs, shape=[3, 3], offsets=offsets) @paddle.jit.to_static - def cusum3(inputs): - return paddle.cumsum(inputs, axis=1) + def crop4(inputs, shape, offsets): + return paddle.crop(inputs, shape=shape, offsets=offsets) + + input_shape = [10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(crop1, input_data=[input_data]) + shape = paddle.to_tensor(np.array([3, 3], "int32")) + verify_model(crop2, [input_data, shape], input_shape=[[-1, -1], [2]]) + verify_model(crop3, input_data=[input_data]) + offsets = paddle.to_tensor(np.array([1, 1]).astype("int32")) + verify_model(crop4, input_data=[input_data, shape, offsets], input_shape=[[-1, -1], [2], [2]]) + + +@tvm.testing.uses_gpu +def test_forward_cumsum(): + class Cumsum1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.cumsum(inputs) + + class Cumsum2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.cumsum(inputs, axis=0) + + class Cumsum3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.cumsum(inputs, axis=1) input_data = paddle.randint(0, 100, (10, 10), dtype=paddle.int32) - verify_model(cusum1, [input_data]) - verify_model(cusum1, [input_data.astype(paddle.int64)]) - verify_model( - cusum2, - [ - input_data, - ], - ) - verify_model( - cusum3, - [ - input_data, - ], - ) + verify_model(Cumsum1(), input_data) + verify_model(Cumsum1(), [input_data.astype(paddle.int64)]) + verify_model(Cumsum2(), input_data) + verify_model(Cumsum3(), input_data) @tvm.testing.uses_gpu def test_forward_conv(): - conv2d_input_shape = [1, 3, 10, 10] - class Conv2D1(nn.Layer): def __init__(self): super(Conv2D1, self).__init__() @@ -315,9 +544,124 @@ def __init__(self): def forward(self, inputs): return self.softmax(self.conv(inputs)) + class Conv2D3(nn.Layer): + def __init__(self): + super(Conv2D3, self).__init__() + self.conv = nn.Conv2D(3, 6, 7, groups=3, bias_attr=False, padding="SAME") + + @paddle.jit.to_static + def forward(self, inputs): + return self.conv(inputs) + + class Conv2D4(nn.Layer): + def __init__(self): + super(Conv2D4, self).__init__() + self.conv = nn.Conv2D( + 3, 6, 7, groups=3, bias_attr=False, padding=[1, 2, 0, 1], stride=2, dilation=2 + ) + + @paddle.jit.to_static + def forward(self, inputs): + return self.conv(inputs) + + conv2d_input_shape = [1, 3, 112, 112] conv2d_input_data = paddle.rand(conv2d_input_shape, dtype="float32") verify_model(Conv2D1(), input_data=conv2d_input_data) verify_model(Conv2D2(), input_data=conv2d_input_data) + verify_model(Conv2D3(), input_data=conv2d_input_data) + verify_model(Conv2D4(), input_data=conv2d_input_data) + verify_model(Conv2D1(), conv2d_input_data, input_shape=[[-1, 3, 112, 112]]) + + +@tvm.testing.uses_gpu +def test_forward_conv_transpose(): + # Note we do not test with groups > 1 because that is not supported + # in tvm for conv transpose operations + + class Conv2DTranspose1(nn.Layer): + def __init__(self): + super(Conv2DTranspose1, self).__init__() + self.conv_transpose = nn.Conv2DTranspose(3, 5, 3) + + @paddle.jit.to_static + def forward(self, inputs): + return self.conv_transpose(inputs) + + class Conv2DTranspose2(nn.Layer): + def __init__(self): + super(Conv2DTranspose2, self).__init__() + self.conv_transpose = nn.Conv2DTranspose( + 3, + 5, + 3, + stride=2, + padding=[[0, 0], [0, 0], [1, 2], [3, 4]], + output_padding=1, + bias_attr=True, + ) + + @paddle.jit.to_static + def forward(self, inputs): + return self.conv_transpose(inputs) + + class 
Conv2DTranspose3(nn.Layer): + def __init__(self): + super(Conv2DTranspose3, self).__init__() + self.conv_transpose = nn.Conv2DTranspose( + 3, 5, 3, stride=3, padding="VALID", output_padding=2, bias_attr=True + ) + + @paddle.jit.to_static + def forward(self, inputs): + return self.conv_transpose(inputs) + + # Conv 2D Transpose Tests + conv2d_transpose_input_shape = [1, 3, 128, 256] + conv2d_transpose_input_data = paddle.rand(conv2d_transpose_input_shape, dtype="float32") + verify_model(Conv2DTranspose1(), input_data=conv2d_transpose_input_data) + verify_model(Conv2DTranspose2(), input_data=conv2d_transpose_input_data) + verify_model(Conv2DTranspose3(), input_data=conv2d_transpose_input_data) + + +@tvm.testing.uses_gpu +def test_forward_dist(): + @paddle.jit.to_static + def dist(x, y): + return paddle.dist(x, y, p=2) + + @paddle.jit.to_static + def dist2(x, y): + return paddle.dist(x, y, p=20) + + @paddle.jit.to_static + def dist3(x, y): + return paddle.dist(x, y, p=float("-inf")) + + @paddle.jit.to_static + def dist4(x, y): + return paddle.dist(x, y, p=float("inf")) + + x_shape = [10, 3] + y_shape = [10, 1] + x_data = paddle.rand(x_shape, dtype="float32") + y_data = paddle.rand(y_shape, dtype="float32") + verify_model(dist, input_data=[x_data, y_data]) + verify_model(dist2, input_data=[x_data, y_data]) + verify_model(dist3, input_data=[x_data, y_data]) + verify_model(dist4, input_data=[x_data, y_data]) + + +@tvm.testing.uses_gpu +def test_forward_dot(): + @paddle.jit.to_static + def dot(x, y): + return paddle.dot(x, y) + + x_shape = [10, 3] + y_shape = [10, 3] + x_data = paddle.rand(x_shape, dtype="float32") + y_data = paddle.rand(y_shape, dtype="float32") + verify_model(dot, input_data=[x_data, y_data]) @tvm.testing.uses_gpu @@ -332,19 +676,70 @@ def dropout(inputs): verify_model(dropout, input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_expand(): + @paddle.jit.to_static + def expand1(inputs): + return paddle.expand(inputs, shape=[2, 3]) + + @paddle.jit.to_static + def expand2(inputs, shape): + return paddle.expand(inputs, shape=shape) + + x_shape = [3] + x_data = paddle.rand(x_shape, dtype="float32") + verify_model(expand1, input_data=[x_data]) + shape = paddle.to_tensor(np.array([2, 3]).astype("int32")) + verify_model(expand2, [x_data, shape], input_shape=[[3], [2]]) + + +@tvm.testing.uses_gpu +def test_forward_expand_as(): + @paddle.jit.to_static + def expand_as(x, y): + z = paddle.expand_as(x, y) + z += y + return z + + data_x = paddle.to_tensor([1, 2, 3], dtype="int32") + data_y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype="float32") + verify_model(expand_as, [data_x, data_y]) + + +@tvm.testing.uses_gpu +def test_forward_flatten(): + @paddle.jit.to_static + def flatten1(inputs): + return paddle.flatten(inputs, start_axis=1, stop_axis=2) + + def flatten2(inputs): + return paddle.flatten(inputs, start_axis=0, stop_axis=-1) + + x_shape = [3, 100, 100, 4] + x_data = paddle.rand(x_shape, dtype="float32") + verify_model(flatten1, input_data=[x_data]) + verify_model(flatten2, input_data=[x_data]) + + @tvm.testing.uses_gpu def test_forward_shape_full(): @paddle.jit.to_static def full1(inputs): - return paddle.full(paddle.shape(inputs), 3.14) + return paddle.full(inputs, 3.14) @paddle.jit.to_static def full2(inputs): return paddle.full(paddle.shape(inputs), 1.0, dtype=inputs.dtype) + @paddle.jit.to_static + def shape1(inputs): + return paddle.shape(inputs) + input_shape = [1, 3, 10, 10] input_data = paddle.rand(input_shape, dtype="float32") - verify_model(full1, 
input_data=[input_data]) + verify_model(shape1, input_data=[input_data]) + shape = paddle.to_tensor(np.array(input_shape, "int32")) + verify_model(full1, input_data=[shape], input_shape=[[4]]) verify_model(full2, input_data=[input_data]) @@ -361,75 +756,357 @@ def ones_like2(inputs): input_shape = [1, 3, 10, 10] input_data = paddle.rand(input_shape, dtype="float32") verify_model(ones_like1, input_data=input_data) - verify_model(ones_like2, input_data=input_data) + verify_model(ones_like2, input_data, input_shape=[[-1, -1, -1, -1]]) @tvm.testing.uses_gpu -def test_forward_gelu(): +def test_forward_ones(): @paddle.jit.to_static - def gelu(inputs): - return nn.functional.gelu(inputs) + def ones1(inputs): + ones = paddle.ones([1, 3, 10, 10]) + out = inputs + ones + return out + + @paddle.jit.to_static + def ones2(inputs): + shape = paddle.to_tensor([1, 3, 10, 10], dtype="int32") + ones = paddle.ones(shape) + out = inputs + ones + return out input_shape = [1, 3, 10, 10] input_data = paddle.rand(input_shape, dtype="float32") - verify_model(gelu, input_data=input_data) + verify_model(ones1, input_data=input_data) + verify_model(ones2, input_data=input_data) -@tvm.testing.uses_gpu -def test_forward_hard_sigmoid(): - @paddle.jit.to_static - def hard_sigmoid(inputs): - return nn.functional.hardsigmoid(inputs) +def test_forward_elemwise(): + class ElemwiseOp(nn.Layer): + def __init__(self, op_name): + super(ElemwiseOp, self).__init__() + self.op_name_ = op_name + for candidate in (paddle, paddle.nn.functional): + self.func = getattr(candidate, op_name, None) + if self.func: + break - input_shape = [1, 3, 10, 10] - input_data = paddle.rand(input_shape, dtype="float32") - verify_model(hard_sigmoid, input_data=input_data) + @paddle.jit.to_static + def forward(self, input1, input2): + y = self.func(input1, input2) + if "equal" in self.op_name_ or "than" in self.op_name_: + y = paddle.cast(y, "int32") + return y + + op_list = [ + "floor_divide", + "floor_mod", + "maximum", + "minimum", + "equal", + "greater_equal", + "greater_than", + "less_equal", + "less_than", + "not_equal", + ] + input_shape = [10, 10] + input_shape_2 = [ + 10, + ] + x_data = paddle.rand(input_shape, dtype="float32") + y_data = paddle.rand(input_shape_2, dtype="float32") + x_data_2 = paddle.randint(1, 100, input_shape_2, dtype="int32") + y_data_2 = paddle.randint(1, 100, input_shape, dtype="int32") + for op_name in op_list: + if op_name not in ["floor_divide"]: + verify_model(ElemwiseOp(op_name), [x_data, y_data]) + verify_model(ElemwiseOp(op_name), [x_data_2, y_data_2]) @tvm.testing.uses_gpu -def test_forward_hard_swish(): +def test_forward_gather_assign_value(): @paddle.jit.to_static - def hard_swish(inputs): - return nn.functional.hardswish(inputs) + def gather1(x, index): + return paddle.gather(x, index, axis=None) - input_shape = [1, 3, 10, 10] - input_data = paddle.rand(input_shape, dtype="float32") - verify_model(hard_swish, input_data=input_data) + @paddle.jit.to_static + def gather2(x): + index = paddle.to_tensor(np.array([1, 3, 5, 7, 9]).astype("int64")) + return paddle.gather(x, index, axis=1) + + x_shape = [30, 40] + x_data = paddle.rand(x_shape, dtype="float32") + index = paddle.to_tensor(np.array([1, 3, 5, 7, 9]).astype("int64")) + verify_model(gather1, [x_data, index], input_shape=[[30, 40], [5]]) + verify_model(gather2, input_data=[x_data]) @tvm.testing.uses_gpu -def test_forward_layer_norm(): +def test_forward_gather_nd(): @paddle.jit.to_static - def layer_norm(inputs, weight, bias): - return 
nn.functional.layer_norm(inputs, inputs.shape[-1], weight=weight, bias=bias) + def gather_nd1(x, index): + return paddle.gather_nd(x, index) - class LayerNorm(nn.Layer): - def __init__(self): - super(LayerNorm, self).__init__() - data_shape = [10] - self.layer_norm = nn.LayerNorm(data_shape) - - @paddle.jit.to_static - def forward(self, inputs): - return self.layer_norm(inputs) + @paddle.jit.to_static + def gather_nd2(x): + index = paddle.to_tensor(np.array([[0, 1], [1, 2]]).astype("int32")) + return paddle.gather_nd(x, index) - input_shape = [1, 3, 10, 10] - input_data = paddle.rand(input_shape, dtype="float32") - weight = paddle.rand([10], dtype="float32") - bias = paddle.rand([10], dtype="float32") - verify_model(layer_norm, input_data=[input_data, weight, bias]) - verify_model(LayerNorm(), input_data=input_data) + x_shape = [30, 40, 20] + x_data = paddle.rand(x_shape, dtype="float32") + index = paddle.to_tensor(np.array([[0, 1]]).astype("int64")) + verify_model(gather_nd1, [x_data, index], input_shape=[[-1, 40, 20], [1, 2]]) + verify_model(gather_nd2, input_data=[x_data]) @tvm.testing.uses_gpu -def test_forward_leaky_relu(): +def test_forward_gelu(): @paddle.jit.to_static - def leaky_relu(inputs): - return nn.functional.leaky_relu(inputs) + def gelu(inputs): + return nn.functional.gelu(inputs) input_shape = [1, 3, 10, 10] input_data = paddle.rand(input_shape, dtype="float32") - verify_model(leaky_relu, input_data=input_data) + verify_model(gelu, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_group_norm(): + class GroupNorm(nn.Layer): + def __init__(self, channels, groups): + super(GroupNorm, self).__init__() + self.group_norm = paddle.nn.GroupNorm(num_channels=channels, num_groups=groups) + + def forward(self, inputs): + return self.group_norm(inputs) + + input_shape = [2, 6, 10, 10] + x = paddle.rand(input_shape, dtype="float32") + verify_model(GroupNorm(6, 6), x, rtol=1e-4, atol=1e-4) + + +@tvm.testing.uses_gpu +def test_forward_activation(): + class Activation(nn.Layer): + def __init__(self, op_name): + super(Activation, self).__init__() + self.op_name_ = op_name + for candidate in (paddle.nn.functional, paddle): + self.func = getattr(candidate, op_name, None) + if self.func: + break + + @paddle.jit.to_static + def forward(self, inputs): + return self.func(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.normal(shape=input_shape) * 10.0 + input_data_2 = paddle.normal(shape=input_shape).astype("float64") * 10.0 + op_list = [ + "elu", + "hardshrink", + "hardsigmoid", + "hardswish", + "hardtanh", + "log_sigmoid", + "log_softmax", + "relu6", + "selu", + "sigmoid", + "softplus", + "softshrink", + "softsign", + "swish", + "tanhshrink", + "thresholded_relu", + ] + for op_name in op_list: + verify_model(Activation(op_name), input_data=input_data) + verify_model(Activation(op_name), input_data=input_data_2) + + +@tvm.testing.uses_gpu +def test_forward_index_select(): + @paddle.jit.to_static + def index_select1(x, index): + return paddle.index_select(x, index) + + @paddle.jit.to_static + def index_select2(x, index): + return paddle.index_select(x, index, axis=1) + + input_shape = [3, 10] + input_data = paddle.rand(input_shape, dtype="float32") + index = paddle.to_tensor(np.array([0, 1, 1]).astype("int32")) + verify_model(index_select1, input_data=[input_data, index]) + verify_model(index_select2, input_data=[input_data, index]) + + +@tvm.testing.uses_gpu +def test_forward_isfinite(): + @paddle.jit.to_static + def isfinite(inputs): + return 
paddle.cast(paddle.isfinite(inputs), "int32") + + input_shape = [5, 5] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(isfinite, input_data=input_data) + + +def test_forward_instance_norm(): + class InstanceNorm(nn.Layer): + def __init__(self): + super(InstanceNorm, self).__init__() + self.instance_norm = paddle.nn.InstanceNorm2D(2) + + def forward(self, inputs): + return self.instance_norm(inputs) + + input_shape = [2, 2, 2, 3] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(InstanceNorm(), input_data) + + +@tvm.testing.uses_gpu +def test_forward_isinf(): + @paddle.jit.to_static + def isinf(inputs): + return paddle.cast(paddle.isinf(inputs), "int32") + + input_shape = [5, 5] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(isinf, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_isnan(): + @paddle.jit.to_static + def isnan(inputs): + return paddle.cast(paddle.isnan(inputs), "int32") + + input_shape = [5, 5] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(isnan, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_interpolate(): + class TestBilinear(nn.Layer): + def __init__(self): + super(TestBilinear, self).__init__() + self.conv = nn.Conv2D(3, 5, 3, stride=2) + + def forward(self, x, size): + y = self.conv(x) + return nn.functional.interpolate(y, size=size, mode="nearest") + + def bilinear_interp1(inputs): + return nn.functional.interpolate(inputs, size=[12, 12], mode="bilinear") + + @paddle.jit.to_static + def bilinear_interp2(inputs): + return nn.functional.interpolate( + inputs, scale_factor=[2.0, 1.0], mode="bilinear", align_corners=True, align_mode=1 + ) + + @paddle.jit.to_static + def bilinear_interp3(inputs): + return nn.functional.interpolate(inputs, scale_factor=[1.0, 2.0], mode="bicubic") + + @paddle.jit.to_static + def bilinear_interp4(inputs): + return nn.functional.interpolate( + inputs, scale_factor=3.0, mode="bicubic", align_corners=True, align_mode=0 + ) + + input_shape = [2, 3, 6, 12] + input_data = paddle.rand(input_shape, dtype="float32") + size = paddle.to_tensor(np.array([15, 15], "int32")) + verify_model(TestBilinear(), [input_data, size], input_shape=[input_shape, [2]]) + verify_model(bilinear_interp1, input_data=input_data) + verify_model(bilinear_interp2, input_data=input_data) + verify_model(bilinear_interp3, input_data=input_data) + verify_model(bilinear_interp4, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_layer_norm(): + @paddle.jit.to_static + def layer_norm(inputs, weight, bias): + return nn.functional.layer_norm(inputs, inputs.shape[-1], weight=weight, bias=bias) + + class LayerNorm(nn.Layer): + def __init__(self): + super(LayerNorm, self).__init__() + data_shape = [10] + self.layer_norm = nn.LayerNorm(data_shape) + + @paddle.jit.to_static + def forward(self, inputs): + return self.layer_norm(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + weight = paddle.rand([10], dtype="float32") + bias = paddle.rand([10], dtype="float32") + verify_model(layer_norm, input_data=[input_data, weight, bias]) + verify_model(LayerNorm(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_leaky_relu(): + @paddle.jit.to_static + def leaky_relu(inputs): + return nn.functional.leaky_relu(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(leaky_relu, input_data=input_data) + + +@tvm.testing.uses_gpu 
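+# each logical op below is checked both with and without the out= argument; the boolean result is cast to int32 so outputs can be compared numerically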
+def test_forward_logical_op(): + class LogicalOp(nn.Layer): + def __init__(self, op_name, out=False): + super(LogicalOp, self).__init__() + self.out = out + for candidate in (paddle, paddle.nn.functional): + self.func = getattr(candidate, op_name, None) + if self.func: + break + + @paddle.jit.to_static + def forward(self, x, y): + if self.out: + out = paddle.to_tensor([True, True, True]) + z = self.func(x, y, out=out) + else: + z = self.func(x, y) + return paddle.cast(z, "int32") + + class LogicalOp_not(LogicalOp): + @paddle.jit.to_static + def forward(self, x): + if self.out: + out = paddle.to_tensor([True, True, True]) + z = self.func(x, out=out) + else: + z = self.func(x) + return paddle.cast(z, "int32") + + op_list = [ + "logical_or", + "logical_xor", + "logical_and", + ] + x = paddle.to_tensor([True]) + y = paddle.to_tensor([True, False, True, False]) + for op_name in op_list: + verify_model(LogicalOp(op_name, False), [x, y]) + verify_model(LogicalOp(op_name, True), [x, y]) + verify_model(LogicalOp_not("logical_not", False), [y]) + verify_model(LogicalOp_not("logical_not", True), [y]) @tvm.testing.uses_gpu @@ -451,7 +1128,105 @@ def forward(self, inputs): input_data = paddle.randint(0, 10, input_shape, dtype="int32") weight = paddle.rand([10, 4], dtype="float32") verify_model(look_up, input_data=[input_data, weight]) - verify_model(LookUp(), input_data=input_data) + verify_model(LookUp(), input_data, input_shape=[[-1, -1, -1, -1]]) + + +@tvm.testing.uses_gpu +def test_forward_lstm(): + class LSTM1(nn.Layer): + def __init__(self): + super(LSTM1, self).__init__() + self.lstm = nn.LSTM(288, 48, 2, direction="bidirect", time_major=True) + + @paddle.jit.to_static + def forward(self, inputs, prev_h, prev_c): + y, (h, c) = self.lstm(inputs, (prev_h, prev_c)) + return y + + class LSTM2(nn.Layer): + def __init__(self): + super(LSTM2, self).__init__() + self.lstm = nn.LSTMCell(16, 32) + + @paddle.jit.to_static + def forward(self, inputs, prev_h, prev_c): + y, (h, c) = self.lstm(inputs, (prev_h, prev_c)) + return y + + lstm_input_shape = [25, 1, 288] + lstm_input_data = paddle.rand(lstm_input_shape, dtype="float32") + prev_h = paddle.rand([4, 1, 48], dtype="float32") + prev_c = paddle.rand([4, 1, 48], dtype="float32") + verify_model(LSTM1(), input_data=[lstm_input_data, prev_h, prev_c]) + lstm_input_shape = [4, 16] + lstm_input_data = paddle.rand(lstm_input_shape, dtype="float32") + prev_h = paddle.rand([4, 32], dtype="float32") + prev_c = paddle.rand([4, 32], dtype="float32") + verify_model(LSTM2(), input_data=[lstm_input_data, prev_h, prev_c]) + + +@tvm.testing.uses_gpu +def test_forward_gru(): + class GRU1(nn.Layer): + def __init__(self): + super(GRU1, self).__init__() + self.gru = nn.GRU(288, 48, 2, direction="bidirect", time_major=True) + + @paddle.jit.to_static + def forward(self, inputs, prev_h): + y, h = self.gru(inputs, prev_h) + return y + + class GRU2(nn.Layer): + def __init__(self): + super(GRU2, self).__init__() + self.gru = nn.GRUCell(16, 32) + + @paddle.jit.to_static + def forward(self, inputs, prev_h): + y, h = self.gru(inputs, prev_h) + return y + + gru_input_shape = [25, 1, 288] + gru_input_data = paddle.rand(gru_input_shape, dtype="float32") + prev_h = paddle.rand([4, 1, 48], dtype="float32") + verify_model(GRU1(), input_data=[gru_input_data, prev_h]) + gru_input_shape = [4, 16] + gru_input_data = paddle.rand(gru_input_shape, dtype="float32") + prev_h = paddle.rand([4, 32], dtype="float32") + verify_model(GRU2(), input_data=[gru_input_data, prev_h]) + + +@tvm.testing.uses_gpu 
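+# mirrors the LSTM/GRU tests above: one stacked bidirectional SimpleRNN case and one single SimpleRNNCell case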
+def test_forward_simplernn(): + class SimpleRNN1(nn.Layer): + def __init__(self): + super(SimpleRNN1, self).__init__() + self.simplernn = nn.SimpleRNN(288, 48, 2, direction="bidirect", time_major=True) + + @paddle.jit.to_static + def forward(self, inputs, prev_h): + y, h = self.simplernn(inputs, prev_h) + return y + + class SimpleRNN2(nn.Layer): + def __init__(self): + super(SimpleRNN2, self).__init__() + self.simplernn = nn.SimpleRNNCell(16, 32) + + @paddle.jit.to_static + def forward(self, inputs, prev_h): + y, h = self.simplernn(inputs, prev_h) + return y + + gru_input_shape = [25, 1, 288] + gru_input_data = paddle.rand(gru_input_shape, dtype="float32") + prev_h = paddle.rand([4, 1, 48], dtype="float32") + verify_model(SimpleRNN1(), input_data=[gru_input_data, prev_h]) + gru_input_shape = [4, 16] + gru_input_data = paddle.rand(gru_input_shape, dtype="float32") + prev_h = paddle.rand([4, 32], dtype="float32") + verify_model(SimpleRNN2(), input_data=[gru_input_data, prev_h]) @tvm.testing.uses_gpu @@ -477,6 +1252,32 @@ def multiply3(inputs, inputs2): verify_model(multiply3, input_data=[input_data, input_data2]) +@tvm.testing.uses_gpu +def test_forward_masked_select(): + @paddle.jit.to_static + def masked_select(x): + mask_data = np.array( + [[True, False, False, False], [True, True, False, False], [True, False, False, False]] + ).astype("bool") + mask = paddle.to_tensor(mask_data) + mask = paddle.logical_not(mask) + return paddle.masked_select(x, mask) + + @paddle.jit.to_static + def masked_select2(x, mask): + return paddle.masked_select(x, mask) + + data = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]).astype( + "float32" + ) + x = paddle.to_tensor(data) + verify_model(masked_select, x) + input_shape = [2, 3, 10] + x = paddle.rand(input_shape, dtype="float32") + mask = paddle.randint(0, 2, input_shape).astype("bool") + verify_model(masked_select2, [x, mask], input_shape=[input_shape, input_shape]) + + @tvm.testing.uses_gpu def test_forward_matmul(): class MatMul1(nn.Layer): @@ -504,6 +1305,147 @@ def forward(self, input1, input2): verify_model(MatMul1(), input_data=[input_data1, input_data2]) +@tvm.testing.uses_gpu +def test_forward_meshgrid(): + @paddle.jit.to_static + def t(x, y, z): + return paddle.meshgrid(x, y, z) + + x = paddle.randint(low=0, high=100, shape=[2]) + y = paddle.randint(low=0, high=100, shape=[3]) + z = paddle.randint(low=0, high=100, shape=[5]) + verify_model(t, [x, y, z]) + + +def test_forward_mm(): + class Mm(nn.Layer): + def forward(self, input1, input2): + return paddle.mm(input1, input2) + + # matrix x vector + input_data1 = paddle.randn((3, 4), dtype="float32") + input_data2 = paddle.randn((4,), dtype="float32") + verify_model(Mm(), input_data=[input_data1, input_data2]) + + # matrix x matrix + input_data1 = paddle.randn((5, 4), dtype="float32") + input_data2 = paddle.randn((4, 5), dtype="float32") + verify_model(Mm(), input_data=[input_data1, input_data2]) + + # batched matrix x batched matrix + input_data1 = paddle.randn((10, 3, 4), dtype="float32") + input_data2 = paddle.randn((10, 4, 5), dtype="float32") + verify_model(Mm(), input_data=[input_data1, input_data2]) + + # batched matrix x broadcasted matrix + input_data1 = paddle.randn((10, 3, 4), dtype="float32") + input_data2 = paddle.randn((4, 5), dtype="float32") + verify_model(Mm(), input_data=[input_data1, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_mv(): + class Mv(nn.Layer): + def forward(self, input1, input2): + return paddle.mv(input1, input2) + + # 
matrix x vector + input_data1 = paddle.randn((3, 4), dtype="float32") + input_data2 = paddle.randn((4,), dtype="float32") + verify_model(Mv(), input_data=[input_data1, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_nonzero(): + class Nonzero(nn.Layer): + def __init__(self, as_tuple=False): + super().__init__() + self.as_tuple = as_tuple + + @paddle.jit.to_static + def forward(self, inputs): + return paddle.nonzero(inputs, self.as_tuple) + + x1 = paddle.to_tensor([[1.0, 0.0, 0.0, 2.0], [0.0, 2.0, 0.0, 1.1], [0.0, 0.0, 3.0, 0.0]]) + verify_model(Nonzero(), x1, input_shape=[[3, 4]]) + verify_model(Nonzero(True), x1, input_shape=[[3, 4]]) + x2 = paddle.to_tensor([0, 1, 0, 3]) + verify_model( + Nonzero(), + x2, + input_shape=[ + [ + 4, + ] + ], + ) + + +def test_forward_norm(): + class Norm1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.norm(inputs, p=float("inf"), axis=None, keepdim=False) + + class Norm2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.norm(inputs, p=float("-inf"), axis=None, keepdim=False) + + class Norm3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.norm(inputs, p=float("-inf"), axis=None, keepdim=True) + + class Norm4(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.norm(inputs, p=float("inf"), axis=[1, 2], keepdim=False) + + class Norm5(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.norm(inputs, p=float("inf"), axis=-1, keepdim=True) + + class Norm6(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.norm(inputs, p=float(0.5), axis=1, keepdim=True) + + class Norm7(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.norm(inputs, p=float(1), axis=None, keepdim=False) + + class Norm8(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.norm(inputs, p=float(2.0), axis=1, keepdim=False) + + class Norm9(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.norm(inputs, p=float(-0.5), axis=[1, 2], keepdim=False) + + class Norm10(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.norm(inputs, p=float(-2), axis=(1), keepdim=False) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Norm1(), input_data=input_data) + verify_model(Norm2(), input_data=input_data) + verify_model(Norm3(), input_data=input_data) + verify_model(Norm4(), input_data=input_data) + verify_model(Norm5(), input_data=input_data) + verify_model(Norm6(), input_data=input_data) + verify_model(Norm7(), input_data=input_data) + verify_model(Norm8(), input_data=input_data) + verify_model(Norm9(), input_data=input_data) + verify_model(Norm10(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_pool2d(): @paddle.jit.to_static @@ -516,32 +1458,222 @@ def pool2d2(inputs): @paddle.jit.to_static def pool2d3(inputs): - return nn.functional.max_pool2d( + output = nn.functional.max_pool2d(inputs, kernel_size=2, stride=2, padding=0) + return output + + @paddle.jit.to_static + def pool2d4(inputs): + output, max_indices = nn.functional.max_pool2d( inputs, kernel_size=2, stride=2, padding=0, return_mask=True ) + return output input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1) - verify_model(pool2d1, input_data=input_data) + verify_model(pool2d1, input_data, input_shape=[[-1, 2, 32, 32]]) verify_model(pool2d2, 
input_data=input_data) - # verify_model(pool2d3, input_data=input_data) + input_data1 = paddle.uniform(shape=[1, 2, 1, 50], dtype="float32", min=-1, max=1) + verify_model(pool2d3, input_data=input_data1) + # need op max_pool2d_with_index + verify_model(pool2d4, input_data=input_data) @tvm.testing.uses_gpu -def test_forward_relu(): +def test_forward_pad(): + class Pad1(nn.Layer): + def __init__(self): + super(Pad1, self).__init__() + self.pad = nn.Pad3D(padding=[1, 2, 3, 4, 5, 6], mode="replicate", value=0.5) + + @paddle.jit.to_static + def forward(self, inputs): + return self.pad(inputs) + @paddle.jit.to_static - def relu(inputs): - return nn.functional.relu(inputs) + def pad2(inputs): + return paddle.nn.functional.pad( + inputs, [1, 3, 1, 4, 1, 0], mode="constant", value=2.2, data_format="NDHWC" + ) - input_shape = [10, 10] + @paddle.jit.to_static + def pad3(inputs): + return paddle.nn.functional.pad( + inputs, [2, 3, 1, 0], mode="reflect", value=2.0, data_format="NCHW" + ) + + @paddle.jit.to_static + def pad4(inputs): + return paddle.nn.functional.pad( + inputs, [2, 1], mode="replicate", value=2.0, data_format="NLC" + ) + + input_data = paddle.rand([2, 3, 6, 7, 8], dtype="float32") + verify_model(Pad1(), input_data=input_data) + verify_model(pad2, input_data=input_data) + input_data = paddle.rand([2, 4, 3, 5], dtype="float32") + verify_model(pad3, input_data=input_data) + input_data = paddle.rand([2, 4, 5], dtype="float32") + verify_model(pad4, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_pixel_shuffle(): + class PixelShuffle(nn.Layer): + def __init__(self, upscale_factor): + super(PixelShuffle, self).__init__() + self.pixel_shuffle = paddle.nn.PixelShuffle(upscale_factor) + + @paddle.jit.to_static + def forward(self, x): + return self.pixel_shuffle(x) + + x = paddle.rand([2, 9, 5, 5], dtype="float32") + verify_model(PixelShuffle(3), x) + x2 = paddle.rand([3, 8, 9, 9], dtype="float32") + verify_model(PixelShuffle(2), x2) + + +@tvm.testing.uses_gpu +def test_forward_prelu(): + class PRelu(nn.Layer): + @paddle.jit.to_static + def forward(self, x, w): + return paddle.nn.functional.prelu(x, w) + + x = paddle.normal(shape=[4, 3, 5, 5]) + w = paddle.to_tensor( + np.array( + [ + 0.25, + ] + ).astype("float32") + ) + verify_model(PRelu(), [x, w]) + w2 = paddle.to_tensor(np.array([0.25, 0.5, 0.8]).astype("float32")) + verify_model(PRelu(), [x, w2]) + + +@tvm.testing.uses_gpu +def test_forward_pow(): + class Pow(nn.Layer): + @paddle.jit.to_static + def forward(self, x): + output = paddle.pow(x, 2) + return output + + class Pow1(nn.Layer): + @paddle.jit.to_static + def forward(self, x): + output = paddle.pow(x, 2.5) + return output + + class Pow2(nn.Layer): + @paddle.jit.to_static + def forward(self, x, y): + output = paddle.pow(x, y) + return output + + x_data = paddle.to_tensor([1, 2, 3], dtype="float32") + y_data = paddle.to_tensor([2], dtype="float32") + verify_model(Pow(), input_data=[x_data]) + verify_model(Pow1(), input_data=[x_data]) + verify_model(Pow2(), input_data=[x_data, y_data]) + + +@tvm.testing.uses_gpu +def test_forward_rank(): + class Rank(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + rank = paddle.rank(inputs) + rank = paddle.unsqueeze(rank, axis=0) + output = inputs + rank + return output + + input_shape = [1, 2, 1, 3, 1] input_data = paddle.rand(input_shape, dtype="float32") - verify_model(relu, input_data=input_data) + verify_model(Rank(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_reduce_op(): + class 
ReduceOp_Bool(nn.Layer): + def __init__(self, op_name): + super(ReduceOp_Bool, self).__init__() + self.func = getattr(paddle, op_name, None) + + @paddle.jit.to_static + def forward(self, inputs): + inputs = paddle.cast(inputs, "bool") + output = self.func(inputs) + output = paddle.cast(output, "int32") + return output + + class ReduceOp_Bool1(ReduceOp_Bool): + @paddle.jit.to_static + def forward(self, inputs): + inputs = paddle.cast(inputs, "bool") + output = self.func(inputs, axis=0) + output = paddle.cast(output, "int32") + return output + + class ReduceOp_Bool2(ReduceOp_Bool): + @paddle.jit.to_static + def forward(self, inputs): + inputs = paddle.cast(inputs, "bool") + output = self.func(inputs, axis=[0, 1, -1], keepdim=True) + output = paddle.cast(output, "int32") + return output + + class ReduceOp_Math(nn.Layer): + def __init__(self, op_name): + super(ReduceOp_Math, self).__init__() + self.func = getattr(paddle, op_name, None) + + @paddle.jit.to_static + def forward(self, inputs): + output = self.func(inputs) + return output + + class ReduceOp_Math1(ReduceOp_Math): + @paddle.jit.to_static + def forward(self, inputs): + output = self.func(inputs, axis=0) + return output + + class ReduceOp_Math2(ReduceOp_Math): + @paddle.jit.to_static + def forward(self, inputs): + output = self.func(inputs, axis=[0, 1], keepdim=True) + return output + + input_data = paddle.randn([1, 2, 3]) + op_list_bool = [ + "all", + "any", + ] + for op_name in op_list_bool: + verify_model(ReduceOp_Bool(op_name), input_data) + verify_model(ReduceOp_Bool1(op_name), input_data) + verify_model(ReduceOp_Bool2(op_name), input_data) + input_data1 = paddle.rand([2, 4, 5], dtype="float32") + op_list_math = [ + "max", + "min", + "prod", + "sum", + "mean", + "logsumexp", + ] + for op_name in op_list_math: + verify_model(ReduceOp_Math(op_name), input_data1) + verify_model(ReduceOp_Math1(op_name), input_data1) + verify_model(ReduceOp_Math2(op_name), input_data1) @tvm.testing.uses_gpu def test_forward_reshape(): @paddle.jit.to_static - def reshape1(inputs, x): - new_shape = paddle.shape(x) + def reshape1(inputs, new_shape): return paddle.reshape(inputs, new_shape) @paddle.jit.to_static @@ -551,7 +1683,7 @@ def reshape2(inputs): @paddle.jit.to_static def reshape3(inputs): data_shape = inputs.shape - return inputs.reshape([data_shape[0] * data_shape[1], data_shape[2]]) + return inputs.reshape([data_shape[1], data_shape[2], data_shape[0]]) @paddle.jit.to_static def reshape4(inputs, x): @@ -561,7 +1693,8 @@ def reshape4(inputs, x): input_shape = [2, 1, 10, 1, 10] input_data = paddle.rand(input_shape, dtype="float32") input_data2 = paddle.randn([2, 1, 10, 10]) - verify_model(reshape1, input_data=[input_data, input_data2]) + new_shape = paddle.shape(input_data2) + verify_model(reshape1, [input_data, new_shape], input_shape=[[2, 1, 10, 1, 10], [4]]) verify_model(reshape2, input_data=input_data) verify_model(reshape3, input_data=paddle.randn((2, 3, 4))) verify_model(reshape4, input_data=[input_data, input_data2]) @@ -587,11 +1720,55 @@ def scale2(inputs): verify_model(scale2, input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_scatter(): + @paddle.jit.to_static + def scatter(x, index, updates): + return paddle.scatter(x, index, updates, overwrite=True) + + @paddle.jit.to_static + def scatter2(x, index, updates): + return paddle.scatter(x, index, updates, overwrite=False) + + x = paddle.rand([10, 8, 5], dtype="float32") + index = paddle.to_tensor( + [ + 2, + 1, + 0, + 6, + ] + ) + updates = paddle.rand([4, 8, 5], 
dtype="float32") + verify_model(scatter, [x, index, updates], input_shape=[[-1, 8, 5], [4], [4, 8, 5]]) + verify_model(scatter2, [x, index, updates]) + + +def test_forward_scatter_nd(): + @paddle.jit.to_static + def scatter_nd(index, updates): + shape = [3, 5, 9, 10] + return paddle.scatter_nd(index, updates, shape) + + @paddle.jit.to_static + def scatter_nd_add(x, index, updates): + return paddle.scatter_nd_add(x, index, updates) + + index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64) + index = paddle.to_tensor(index_data) + updates = paddle.rand(shape=[3, 9, 10], dtype="float32") + verify_model(scatter_nd, [index, updates]) + x = paddle.rand(shape=[3, 5, 4, 9, 10], dtype="float32") + updates = paddle.rand(shape=[3, 2, 9, 10], dtype="float32") + index = paddle.randint(0, 3, shape=[3, 2, 3]) + verify_model(scatter_nd_add, [x, index, updates]) + + @tvm.testing.uses_gpu def test_forward_slice(): @paddle.jit.to_static - def slice1(inputs): - return inputs[:, :, :, :3] + def slice1(inputs, end): + return inputs[:, :, :, :end] @paddle.jit.to_static def slice2(inputs): @@ -607,55 +1784,442 @@ def slice4(inputs): x1 = paddle.to_tensor([3]) + paddle.to_tensor([1]) return inputs[:, x0:, 1:x1, :] + @paddle.jit.to_static + def slice5(inputs): + x0 = paddle.to_tensor([3]) + return inputs[:, 1::1, 2::x0, 4:10] + input_shape = [1, 3, 10, 10] input_data = paddle.rand(input_shape, dtype="float32") - verify_model( - slice1, - input_data=[ - input_data, - ], - ) + end = paddle.to_tensor(np.array([3])) + verify_model(slice1, [input_data, end], input_shape=[[1, 3, 10, 10], [1]]) verify_model(slice2, input_data=input_data) - # need op "strided_slice" - # verify_model(slice3, input_data=paddle.randn((4, 4))) - # need op "assign_value" - # verify_model(slice4, input_data=input_data) + verify_model(slice3, input_data=paddle.randn((4, 4))) + verify_model(slice4, input_data=input_data) + verify_model(slice5, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_sort(): + @paddle.jit.to_static + def sort(inputs): + return paddle.sort(inputs) + + @paddle.jit.to_static + def sort2(inputs): + return paddle.sort(inputs, axis=0, descending=True) + + input_shape = [2, 3, 5] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(sort, input_data) + input_data2 = np.random.randint(100, size=input_shape) + verify_model(sort2, input_data2) + + +@tvm.testing.uses_gpu +def test_forward_split(): + @paddle.jit.to_static + def split(inputs): + return paddle.split(inputs, 2, axis=paddle.to_tensor([0], "int32")) + + @paddle.jit.to_static + def split2(inputs): + return paddle.split(inputs, [1, 2, -1], axis=1) + + @paddle.jit.to_static + def split3(inputs): + return paddle.split(inputs, [paddle.to_tensor([2]), 1, paddle.to_tensor(3)], axis=0) + + input_shape = [6, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(split, input_data=input_data) + verify_model(split2, input_data=input_data) + verify_model(split3, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_squeeze(): + @paddle.jit.to_static + def squeeze(inputs): + return paddle.squeeze(inputs) + + @paddle.jit.to_static + def squeeze2(inputs): + return paddle.squeeze(inputs, axis=0) + + @paddle.jit.to_static + def squeeze3(inputs): + return paddle.squeeze(inputs, axis=[0, -1]) + + input_shape = [1, 2, 1, 3, 1] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(squeeze, input_data=input_data) + verify_model(squeeze2, input_data=input_data) + verify_model(squeeze3, 
input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_stack(): + @paddle.jit.to_static + def stack(input1, input2, input3): + return paddle.stack([input1, input2, input3]) + + @paddle.jit.to_static + def stack2(input1, input2, input3): + return paddle.stack([input1, input2, input3], axis=-1) + + input_shape = [2, 3] + input_data = paddle.rand(input_shape, dtype="float32") + input_data2 = paddle.rand(input_shape, dtype="float32") + input_data3 = paddle.rand(input_shape, dtype="float32") + verify_model(stack, input_data=[input_data, input_data2, input_data3]) + verify_model(stack2, input_data=[input_data, input_data2, input_data3]) + + +@tvm.testing.uses_gpu +def test_forward_std(): + class Std1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.std(inputs, 1, unbiased=False) + + class Std2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.std(inputs, axis=-2, keepdim=False, unbiased=False) + + class Std3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.std(inputs, axis=3, keepdim=True, unbiased=False) + + class Std4(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.std(inputs, axis=[2, 3], keepdim=True, unbiased=False) + + class Std5(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.std(inputs, axis=[2, 3], keepdim=False, unbiased=False) + + class Std6(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.std(inputs, unbiased=False) + + class Std7(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.std(inputs, unbiased=True) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Std1(), input_data=input_data) + verify_model(Std2(), input_data=input_data) + verify_model(Std3(), input_data=input_data) + verify_model(Std4(), input_data=input_data) + verify_model(Std5(), input_data=input_data) + verify_model(Std6(), input_data=input_data) + verify_model(Std7(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_subtract(): + class Subtract(nn.Layer): + @paddle.jit.to_static + def forward(self, x, y): + return paddle.subtract(x, y) + + input_data1 = paddle.to_tensor([2, np.nan, 5], dtype="float32") + input_data2 = paddle.to_tensor([1, 4, np.nan], dtype="float32") + verify_model(Subtract(), input_data=[input_data1, input_data2]) + + input_data1 = paddle.randint(0, 10, (3, 4), dtype="int32") + input_data2 = paddle.randint(0, 10, (4,), dtype="int32") + verify_model(Subtract(), input_data=[input_data1, input_data2]) + + input_data1 = paddle.randint(0, 10, (10, 3, 4), dtype="int64") + input_data2 = paddle.randint(0, 10, (3, 4), dtype="int64") + verify_model(Subtract(), input_data=[input_data1, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_t(): + class T(nn.Layer): + def forward(self, x): + return paddle.t(x) + + input_data1 = paddle.randn((3, 4), dtype="float32") + verify_model(T(), input_data=[input_data1]) + + +@tvm.testing.uses_gpu +def test_forward_topk(): + class Topk1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.topk(inputs, k=3) + + class Topk2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.topk(inputs, k=3, axis=-2) + + class Topk3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.topk(inputs, k=3, axis=3) + + class Topk4(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return 
paddle.topk(inputs, k=3, largest=True) + + class Topk5(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.topk(inputs, k=3, largest=False) + + class Topk6(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.topk(inputs, k=3, sorted=True) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Topk1(), input_data=input_data) + verify_model(Topk2(), input_data=input_data) + verify_model(Topk3(), input_data=input_data) + verify_model(Topk4(), input_data=input_data) + verify_model(Topk5(), input_data=input_data) + verify_model(Topk6(), input_data=input_data) @tvm.testing.uses_gpu -def test_forward_tanh(): +def test_forward_tile(): @paddle.jit.to_static - def tanh(inputs): - return paddle.tanh(inputs) + def tile(inputs, shape): + return paddle.tile(inputs, shape) + + @paddle.jit.to_static + def tile2(inputs, inputs2): + inputs2 = paddle.shape(inputs2) + inputs2 = paddle.cast(inputs2, "int32") + return paddle.tile(inputs, inputs2) + + @paddle.jit.to_static + def tile3(inputs, inputs2): + inputs2 = paddle.shape(inputs2)[0] + inputs2 = paddle.cast(inputs2, "int32") + return paddle.tile(inputs, [inputs2, 3]) + + input_shape = [2, 2] + input_data = paddle.rand(input_shape, dtype="float32") + input_data2 = paddle.rand([1, 2], dtype="float32") + shape = paddle.to_tensor(np.array([3, 2], "int32")) + verify_model(tile, [input_data, shape], input_shape=[[2, 2], [2]]) + verify_model(tile2, input_data=[input_data, input_data2]) + verify_model(tile3, input_data=[input_data, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_unstack(): + @paddle.jit.to_static + def unstack1(x): + return paddle.unstack(x) + + @paddle.jit.to_static + def unstack2(x): + return paddle.unstack(x, axis=-1) + + @paddle.jit.to_static + def unstack3(x): + return paddle.unstack(x, axis=-1, num=3) + + input_shape = [2, 3] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(unstack1, input_data=[input_data]) + verify_model(unstack2, input_data=[input_data]) + verify_model(unstack3, input_data=[input_data]) + + +@tvm.testing.uses_gpu +def test_forward_unique(): + @paddle.jit.to_static + def unique1(x): + return paddle.unique(x) + + @paddle.jit.to_static + def unique2(x): + return paddle.unique(x, return_index=True, return_inverse=False, return_counts=False) + + @paddle.jit.to_static + def unique3(x): + return paddle.unique(x, return_index=False, return_inverse=True, return_counts=False) + + @paddle.jit.to_static + def unique4(x): + return paddle.unique(x, return_index=False, return_inverse=False, return_counts=True) + + @paddle.jit.to_static + def unique5(x): + return paddle.unique(x, return_index=True, return_inverse=True, return_counts=False) + + @paddle.jit.to_static + def unique6(x): + return paddle.unique(x, return_index=False, return_inverse=True, return_counts=True) + + @paddle.jit.to_static + def unique7(x): + return paddle.unique(x, return_index=True, return_inverse=False, return_counts=True) + + @paddle.jit.to_static + def unique8(x): + return paddle.unique(x, return_index=True, return_inverse=True, return_counts=True) + + input_data = np.array([2, 3, 3, 1, 5, 3]) + input_data = paddle.to_tensor(input_data) + verify_model(unique1, input_data=[input_data], input_shape=[[6]]) + verify_model(unique2, input_data=[input_data], input_shape=[[6]]) + verify_model(unique3, input_data=[input_data], input_shape=[[6]]) + verify_model(unique4, input_data=[input_data], input_shape=[[6]]) + 
verify_model(unique5, input_data=[input_data], input_shape=[[6]]) + verify_model(unique6, input_data=[input_data], input_shape=[[6]]) + verify_model(unique7, input_data=[input_data], input_shape=[[6]]) + verify_model(unique8, input_data=[input_data], input_shape=[[6]]) + + +@tvm.testing.uses_gpu +def test_forward_zeros(): + @paddle.jit.to_static + def zeros1(inputs): + zeros = paddle.zeros([1, 3, 10, 10]) + out = inputs + zeros + return out + + @paddle.jit.to_static + def zeros2(inputs): + shape = paddle.to_tensor([1, 3, 10, 10], dtype="int32") + zeros = paddle.zeros(shape) + out = inputs + zeros + return out input_shape = [1, 3, 10, 10] input_data = paddle.rand(input_shape, dtype="float32") - verify_model(tanh, input_data=input_data) + verify_model(zeros1, input_data=input_data) + verify_model(zeros2, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_where(): + class Where(nn.Layer): + @paddle.jit.to_static + def forward(self, c, x, y): + return paddle.where(c, x, y) + + x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2]) + y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0]) + verify_model(Where(), [x < 1, x, y]) + input_shape = [1, 3, 10, 10] + x = paddle.rand(input_shape, dtype="float32") + y = paddle.rand(input_shape, dtype="float32") + verify_model(Where(), [x < y, x, y]) + + +@tvm.testing.uses_gpu +def test_forward_while(): + class While(nn.Layer): + def __init__(self): + super(While, self).__init__() + + def forward(self, x): + s = paddle.shape(x) + i = paddle.slice(s, axes=[0], starts=[0], ends=[1]) + y = paddle.to_tensor(np.array([5]).astype("int32")) + while i < y: + i *= np.array([3], dtype="int32") + return i + + input_data1 = paddle.rand([1, 3, 224, 224], dtype="float32") + verify_model(While(), input_data=[input_data1], input_shape=[[-1, 3, -1, -1]]) if __name__ == "__main__": test_forward_add_subtract() + test_forward_addmm() + test_forward_addn() + test_forward_arange() test_forward_argmax() + test_forward_argmin() + test_forward_argsort() test_forward_assign() test_forward_batch_norm() test_forward_cast() + test_forward_clip() test_forward_concat_unsqueeze() - test_forward_cumsum() test_forward_conv() + test_forward_crop() + test_forward_cumsum() + test_forward_dist() + test_forward_dot() test_forward_dropout() + test_forward_elemwise() + test_forward_expand() + test_forward_expand_as() + test_forward_flatten() test_forward_shape_full() + test_forward_ones() test_forward_ones_like() + test_forward_gather_assign_value() + test_forward_gather_nd() test_forward_gelu() - test_forward_hard_sigmoid() - test_forward_hard_swish() + test_forward_group_norm() + test_forward_math() + test_forward_activation() + test_forward_index_select() + test_forward_instance_norm() + test_forward_interpolate() + test_forward_isinf() test_forward_layer_norm() test_forward_leaky_relu() + test_forward_logical_op() test_forward_look_up() - test_forward_multiply() + test_forward_lstm() + test_forward_gru() + test_forward_simplernn() + test_forward_masked_select() test_forward_matmul() + test_forward_meshgrid() + test_forward_mm() + test_forward_mv() + test_forward_multiply() + test_forward_nonzero() + test_forward_norm() test_forward_pool2d() - test_forward_relu() + test_forward_pad() + test_forward_pixel_shuffle() + test_forward_prelu() + test_forward_pow() + test_forward_rank() + test_forward_reduce_op() test_forward_reshape() test_forward_scale() + test_forward_scatter() + test_forward_scatter_nd() test_forward_slice() - test_forward_tanh() + test_forward_sort() + test_forward_split() + 
test_forward_squeeze()
+    test_forward_std()
+    test_forward_subtract()
+    test_forward_t()
+    test_forward_topk()
+    test_forward_tile()
+    test_forward_conv_transpose()
+    test_forward_unstack()
+    test_forward_unique()
+    test_forward_zeros()
+    test_forward_where()
+    test_forward_while()
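
Editor's note: every case above funnels through the verify_model helper defined earlier in this file (outside this hunk). For readers of the diff, the sketch below shows the export -> import -> compare flow such a helper presumably follows; the function name check_paddle_vs_tvm, the temporary save path, and the tolerance handling are illustrative assumptions, not the actual helper. Only public Paddle and TVM APIs are used (paddle.static.InputSpec, paddle.jit.save/load, relay.frontend.from_paddle, relay.build, graph_executor).

# Hypothetical sketch of a verify_model-style check; the real helper in this file may differ.
import os
import tempfile

import numpy as np
import paddle
import tvm
from tvm import relay
from tvm.contrib import graph_executor


def check_paddle_vs_tvm(func, input_data, rtol=1e-5, atol=1e-5):
    if not isinstance(input_data, (tuple, list)):
        input_data = [input_data]

    # One InputSpec per input; the same names key the Relay shape dict and the feed dict.
    input_specs, shape_dict, feed = [], {}, {}
    for idx, data in enumerate(input_data):
        name = "input{}".format(idx)
        input_specs.append(paddle.static.InputSpec(shape=data.shape, dtype=data.dtype, name=name))
        shape_dict[name] = data.shape
        feed[name] = data.numpy()

    # Reference outputs from eager PaddlePaddle.
    baseline = func(*input_data)
    baseline = baseline if isinstance(baseline, (tuple, list)) else [baseline]
    baseline = [b.numpy() for b in baseline]

    # Export the @to_static function / Layer and re-load it as a TranslatedLayer,
    # which relay.frontend.from_paddle accepts alongside the shape dict.
    model_path = os.path.join(tempfile.mkdtemp(), "model")
    paddle.jit.save(func, model_path, input_spec=input_specs)
    translated_layer = paddle.jit.load(model_path)
    mod, params = relay.frontend.from_paddle(translated_layer, shape_dict)

    # Compile for CPU and run with the graph executor.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)
    runtime = graph_executor.GraphModule(lib["default"](tvm.cpu()))
    runtime.run(**feed)

    # Compare every TVM output against the eager Paddle result.
    for i, expected in enumerate(baseline):
        np.testing.assert_allclose(runtime.get_output(i).numpy(), expected, rtol=rtol, atol=atol)

Under these assumptions, a call such as verify_model(Pow(), input_data=[x_data]) above corresponds to check_paddle_vs_tvm(Pow(), [x_data]).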