From f4fad89210a52a3fda556de4d9999efb3c77a2c9 Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Wed, 27 Oct 2021 08:49:35 +0000
Subject: [PATCH 1/5] update ci-gpu to v0.78

---
 Jenkinsfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Jenkinsfile b/Jenkinsfile
index 6a7a1f4e3d36..6f3364bae3f8 100755
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -46,7 +46,7 @@ import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
 // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. -->
 ci_lint = "tlcpack/ci-lint:v0.67"
-ci_gpu = "tlcpack/ci-gpu:v0.77"
+ci_gpu = "tlcpack/ci-gpu:v0.78"
 ci_cpu = "tlcpack/ci-cpu:v0.78"
 ci_wasm = "tlcpack/ci-wasm:v0.71"
 ci_i386 = "tlcpack/ci-i386:v0.74"

From 940ed587d292b951333cd374bdb02cbf00a28ec9 Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Fri, 29 Oct 2021 04:03:52 +0000
Subject: [PATCH 2/5] add some common operators

---
 python/tvm/relay/frontend/paddlepaddle.py | 195 ++++++++++++++++++++++
 1 file changed, 195 insertions(+)

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index a568be469e98..d36e3a74078b 100644
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -115,6 +115,32 @@ def convert_binary_logical_op(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_addmm(g, op, block):
+    """Operator converter for addmm."""
+
+    input_x = g.get_node(op.input("Input")[0])
+    x = g.get_node(op.input("X")[0])
+    y = g.get_node(op.input("Y")[0])
+
+    alpha = op.attr("Alpha")
+    beta = op.attr("Beta")
+    dtype = block.var(op.output("Out")[0]).dtype
+    dtype = str(dtype).strip().split(".")[1]
+
+    if not isinstance(alpha, _expr.Expr) and alpha != 1:
+        alpha = _expr.const(alpha, dtype)
+        x *= alpha
+
+    if not isinstance(beta, _expr.Expr) and beta != 1:
+        beta = _expr.const(beta, dtype)
+        input_x *= beta
+
+    transposed_y = _op.transpose(y, axes=[1, 0])
+    dense_out = _op.nn.dense(x, transposed_y)
+    out = dense_out + input_x
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_arg_max_min(g, op, block):
     """Operator converter for arg_max and arg_min."""
 
@@ -192,6 +218,26 @@ def convert_batch_norm(g, op, block):
     g.add_node(op.output("Y")[0], out[0])
 
 
+def convert_bmm(g, op, block):
+    """Operator converter for bmm."""
+
+    x = g.get_node(op.input("X")[0])
+    y = g.get_node(op.input("Y")[0])
+    y = _op.transpose(y, [0, 2, 1])
+    out = _op.nn.batch_matmul(x, y)
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_brelu(g, op, block):
+    """Operator converter for brelu."""
+
+    x = g.get_node(op.input("X")[0])
+    t_max = op.attr("t_max")
+    t_min = op.attr("t_min")
+    out = _op.tensor.clip(x, t_min, t_max)
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_cast(g, op, block):
     """Operator converter for cast."""
 
@@ -413,6 +459,29 @@ def convert_fill_constant(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_gather(g, op, block):
+    """Operator converter for gather."""
+
+    x = g.get_node(op.input("X")[0])
+    index = g.get_node(op.input("Index")[0])
+    axis = op.attr("axis")
+    out = _op.take(x, index, axis)
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_gather_nd(g, op, block):
+    """Operator converter for gather_nd."""
+
+    x = g.get_node(op.input("X")[0])
+    index = g.get_node(op.input("Index")[0])
+    shape = infer_shape(index)
+    perm = list(range(0, len(shape) - 1))
+    perm.insert(0, len(shape) - 1)
+    index = _op.transpose(index, axes=perm)
+    out = _op.gather_nd(x, index, 0, shape[-1])
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_gelu(g, op, block):
     """Operator converter for gelu."""
 
@@ -424,6 +493,39 @@ def convert_gelu(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_group_norm(g, op, block):
+    """Operator converter for group_norm."""
+
+    x = g.get_node(op.input("X")[0])
+    num_groups = op.attr("groups")
+    epsilon = op.attr("epsilon")
+    gamma = g.get_node(op.input("Scale")[0])
+    beta = g.get_node(op.input("Bias")[0])
+    out = _op.nn.group_norm(
+        x,
+        gamma=gamma,
+        beta=beta,
+        num_groups=num_groups,
+        axis=1,
+        epsilon=epsilon,
+        center=True,
+        scale=True,
+    )
+    g.add_node(op.output("Y")[0], out)
+
+
+def convert_hard_shrink(g, op, block):
+    """Operator converter for hard_shrink."""
+
+    x = g.get_node(op.input("X")[0])
+    dtype = infer_type(x).checked_type.dtype
+    threshold = op.attr("threshold")
+    threshold = _op.const(threshold, dtype)
+    out = _op.logical_or(x < _op.const(-1.0, dtype) * threshold, x > threshold)
+    out = _op.cast(out, dtype) * x
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_hard_sigmoid(g, op, block):
     """Operator converter for hard_sigmoid."""
 
@@ -490,6 +592,15 @@ def convert_leaky_relu(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_logical_not(g, op, block):
+    """Operator converter for logical_not op."""
+
+    ipt0 = g.get_node(op.input("X")[0])
+    op_func = get_relay_op(op.type)
+    out = op_func(ipt0)
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_lookup_table(g, op, block):
     """Operator converter for lookup_table_v2."""
 
@@ -763,6 +874,15 @@ def convert_pool2d(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_reciprocal(g, op, block):
+    """Operator converter for reciprocal."""
+
+    x = g.get_node(op.input("X")[0])
+    dtype = infer_type(x).checked_type.dtype
+    out = _expr.const(1.0, dtype) / x
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_reduce(g, op, block):
     """Operator converter for series of reduce operators."""
 
@@ -790,6 +910,14 @@ def convert_reduce(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_relu6(g, op, block):
+    """Operator converter for relu6."""
+
+    x = g.get_node(op.input("X")[0])
+    out = _op.clip(x, 0.0, 6.0)
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_reshape(g, op, block):
     """Operator converter for reshape."""
 
@@ -843,6 +971,40 @@ def convert_scale(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_scatter(g, op, block):
+    """Operator converter for scatter."""
+
+    x = g.get_node(op.input("X")[0])
+    index = g.get_node(op.input("Ids")[0])
+    updates = g.get_node(op.input("Updates")[0])
+    overwrite = op.attr("overwrite")
+
+    shape = infer_shape(updates)
+    ndims = len(shape)
+    index = _op.expand_dims(index, axis=-1, num_newaxis=ndims - 1)
+    index = _op.transform.broadcast_to(index, shape)
+
+    if overwrite:
+        out = _op.scatter(x, index, updates, axis=0)
+    else:
+        out = _op.scatter_add(_op.zeros_like(x), index, updates, axis=0)
+        out += _op.scatter(x, index, _op.zeros_like(updates), axis=0)
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_scatter_nd_add(g, op, block):
+    """Operator converter for scatter_nd_add."""
+
+    x = g.get_node(op.input("X")[0])
+    index = g.get_node(op.input("Index")[0])
+    updates = g.get_node(op.input("Updates")[0])
+    indices_dim = len(infer_shape(index))
+    axes = list(range(indices_dim))
+    index = _op.transpose(index, axes[-1:] + axes[:-1])
+    out = _op.scatter_nd(x, index, updates, mode="add")
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_shape(g, op, block):
     """Operator converter for shape."""
 
@@ -905,6 +1067,16 @@ def convert_softmax(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_square(g, op, block):
+    """Operator converter for square."""
+
+    x = g.get_node(op.input("X")[0])
+    dtype = block.var(op.output("Out")[0]).dtype
+    dtype = _convert_dtype_value(dtype)
+    out = _op.pow(x, _expr.const(2, dtype))
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_squeeze(g, op, block):
     """Operator converter for squeeze2."""
 
@@ -929,6 +1101,7 @@ def convert_unsqueeze(g, op, block):
 _convert_map = {
     "abs": convert_unary_op,
     "acos": convert_unary_op,
+    "addmm": convert_addmm,
     "arg_max": convert_arg_max_min,
     "arg_min": convert_arg_max_min,
     "argsort": convert_argsort,
@@ -937,6 +1110,8 @@ def convert_unsqueeze(g, op, block):
     "assign_value": convert_assign_value,
     "atan": convert_unary_op,
     "batch_norm": convert_batch_norm,
+    "bmm": convert_bmm,
+    "brelu": convert_brelu,
     "cast": convert_cast,
     "ceil": convert_unary_op,
     "concat": convert_concat,
@@ -949,7 +1124,12 @@ def convert_unsqueeze(g, op, block):
     "dropout": convert_dropout,
     "elementwise_add": convert_elementwise_op,
     "elementwise_div": convert_elementwise_op,
+    "elementwise_floordiv": convert_elementwise_op,
+    "elementwise_max": convert_elementwise_op,
+    "elementwise_min": convert_elementwise_op,
     "elementwise_mul": convert_elementwise_op,
+    "elementwise_pow": convert_elementwise_op,
+    "elementwise_prod": convert_elementwise_op,
     "elementwise_sub": convert_elementwise_op,
     "equal": convert_elementwise_op,
     "erf": convert_unary_op,
@@ -960,7 +1140,14 @@ def convert_unsqueeze(g, op, block):
     "fill_any_like": convert_fill_any_like,
     "fill_constant": convert_fill_constant,
     "floor": convert_unary_op,
+    "floor_mod": convert_elementwise_op,
+    "gather": convert_gather,
+    "gather_nd": convert_gather_nd,
     "gelu": convert_gelu,
+    "greater_equal": convert_elementwise_op,
+    "greater_than": convert_elementwise_op,
+    "group_norm": convert_group_norm,
+    "hard_shrink": convert_hard_shrink,
     "hard_sigmoid": convert_hard_sigmoid,
     "hard_swish": convert_hard_swish,
     "isfinite_v2": convert_unary_op,
@@ -974,17 +1161,22 @@ def convert_unsqueeze(g, op, block):
     "log2": convert_unary_op,
     "log10": convert_unary_op,
     "logical_and": convert_binary_logical_op,
+    "logical_not": convert_logical_not,
     "logical_or": convert_binary_logical_op,
     "logical_xor": convert_binary_logical_op,
     "lookup_table_v2": convert_lookup_table,
     "matmul": convert_matmul,
     "matmul_v2": convert_matmul,
     "mul": convert_mul,
+    "pad1d": convert_padding,
+    "pad2d": convert_padding,
     "pad3d": convert_padding,
     "pool2d": convert_pool2d,
     "relu": convert_unary_op,
+    "relu6": convert_relu6,
     "reshape2": convert_reshape,
     "round": convert_unary_op,
+    "reciprocal": convert_reciprocal,
     "reduce_all": convert_reduce,
     "reduce_any": convert_reduce,
     "reduce_max": convert_reduce,
@@ -994,6 +1186,8 @@ def convert_unsqueeze(g, op, block):
     "reduce_mean": convert_reduce,
     "rsqrt": convert_unary_op,
     "scale": convert_scale,
+    "scatter": convert_scatter,
+    "scatter_nd_add": convert_scatter_nd_add,
     "shape": convert_shape,
     "sigmoid": convert_unary_op,
     "sign": convert_unary_op,
@@ -1002,6 +1196,7 @@ def convert_unsqueeze(g, op, block):
     "slice": convert_slice,
     "softmax": convert_softmax,
     "sqrt": convert_unary_op,
+    "square": convert_square,
     "squeeze2": convert_squeeze,
     "tan": convert_unary_op,
     "tanh": convert_unary_op,

From 7c49ab9914da8fe45c82c30676ea51ecf8cf2f50 Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Fri, 29 Oct 2021 06:02:47 +0000
Subject: [PATCH 3/5] code format

---
 python/tvm/relay/frontend/paddlepaddle.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index d36e3a74078b..bc21fbc0b74a 100644
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -125,7 +125,7 @@ def convert_addmm(g, op, block):
     alpha = op.attr("Alpha")
     beta = op.attr("Beta")
     dtype = block.var(op.output("Out")[0]).dtype
-    dtype = str(dtype).strip().split(".")[1]
+    dtype = convert_dtype_value(dtype)
 
     if not isinstance(alpha, _expr.Expr) and alpha != 1:
         alpha = _expr.const(alpha, dtype)
@@ -1073,7 +1073,7 @@ def convert_square(g, op, block):
     x = g.get_node(op.input("X")[0])
     dtype = block.var(op.output("Out")[0]).dtype
     dtype = _convert_dtype_value(dtype)
-    out = _op.pow(x, _expr.const(2, dtype))
+    out = _op.power(x, _expr.const(2, dtype))
     g.add_node(op.output("Out")[0], out)
 
 
@@ -1101,7 +1101,7 @@ def convert_unsqueeze(g, op, block):
 _convert_map = {
     "abs": convert_unary_op,
     "acos": convert_unary_op,
-    "addmm": convert_addmm,
+    "addmm": convert_addmm,
     "arg_max": convert_arg_max_min,
     "arg_min": convert_arg_max_min,
     "argsort": convert_argsort,
@@ -1110,8 +1110,8 @@ def convert_unsqueeze(g, op, block):
     "assign_value": convert_assign_value,
     "atan": convert_unary_op,
     "batch_norm": convert_batch_norm,
-    "bmm": convert_bmm,
-    "brelu": convert_brelu,
+    "bmm": convert_bmm,
+    "brelu": convert_brelu,
     "cast": convert_cast,
     "ceil": convert_unary_op,
     "concat": convert_concat,
@@ -1140,14 +1140,14 @@ def convert_unsqueeze(g, op, block):
     "fill_any_like": convert_fill_any_like,
     "fill_constant": convert_fill_constant,
     "floor": convert_unary_op,
-    "floor_mod": convert_elementwise_op,
+    "floor_mod": convert_elementwise_op,
     "gather": convert_gather,
     "gather_nd": convert_gather_nd,
     "gelu": convert_gelu,
     "greater_equal": convert_elementwise_op,
     "greater_than": convert_elementwise_op,
-    "group_norm": convert_group_norm,
-    "hard_shrink": convert_hard_shrink,
+    "group_norm": convert_group_norm,
+    "hard_shrink": convert_hard_shrink,
     "hard_sigmoid": convert_hard_sigmoid,
     "hard_swish": convert_hard_swish,
     "isfinite_v2": convert_unary_op,
@@ -1173,10 +1173,10 @@ def convert_unsqueeze(g, op, block):
     "pad3d": convert_padding,
     "pool2d": convert_pool2d,
     "relu": convert_unary_op,
-    "relu6": convert_relu6,
+    "relu6": convert_relu6,
     "reshape2": convert_reshape,
     "round": convert_unary_op,
-    "reciprocal": convert_reciprocal,
+    "reciprocal": convert_reciprocal,
     "reduce_all": convert_reduce,
     "reduce_any": convert_reduce,
     "reduce_max": convert_reduce,
@@ -1196,7 +1196,7 @@ def convert_unsqueeze(g, op, block):
     "slice": convert_slice,
     "softmax": convert_softmax,
     "sqrt": convert_unary_op,
-    "square": convert_square,
+    "square": convert_square,
     "squeeze2": convert_squeeze,
     "tan": convert_unary_op,
     "tanh": convert_unary_op,

From 5a5a508095663e74e3cd0ca2b7cb25e672866140 Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Fri, 29 Oct 2021 06:17:53 +0000
Subject: [PATCH 4/5] add transpose and swish

---
 python/tvm/relay/frontend/paddlepaddle.py | 81 +++++++++++++++++++++++
 1 file changed, 81 insertions(+)

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index bc21fbc0b74a..e11086b1f083 100644
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -874,6 +874,18 @@ def convert_pool2d(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_pow(g, op, block):
+    """Operator converter for pow."""
+
+    x = g.get_node(op.input("X")[0])
+    dtype = block.var(op.output("Out")[0]).dtype
+    dtype = _convert_dtype_value(dtype)
+    factor = op.attr("factor")
+    factor = _expr.const(factor, dtype=dtype)
+    out = _op.power(x, factor)
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_reciprocal(g, op, block):
     """Operator converter for reciprocal."""
 
@@ -1005,6 +1017,22 @@ def convert_scatter_nd_add(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_selu(g, op, block):
+    """Operator converter for selu."""
+
+    x = g.get_node(op.input("X")[0])
+    dtype = infer_type(x).checked_type.dtype
+    alpha = _op.const(op.attr("alpha"), dtype)
+    scale = _op.const(op.attr("scale"), dtype)
+    out = (
+        _expr.const(-1.0, dtype=dtype)
+        * alpha
+        * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(x))
+    )
+    out = scale * (out + _op.nn.relu(x))
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_shape(g, op, block):
     """Operator converter for shape."""
 
@@ -1013,6 +1041,15 @@ def convert_shape(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_size(g, op, block):
+    """Operator converter for size."""
+
+    input_x = g.get_node(op.input("Input")[0])
+    out = _op.ndarray_size(input_x, dtype="int64")
+    out = _op.expand_dims(out, axis=0)
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_slice(g, op, block):
     """Operator converter for slice."""
 
@@ -1067,6 +1104,26 @@ def convert_softmax(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_softplus(g, op, block):
+    """Operator converter for softplus."""
+
+    x = g.get_node(op.input("X")[0])
+    dtype = infer_type(x).checked_type.dtype
+    beta = op.attr("beta")
+    beta = _expr.const(beta, dtype=dtype)
+    out = _op.log(_op.exp(x * beta) + _expr.const(1.0, dtype=dtype)) / beta
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_softsign(g, op, block):
+    """Operator converter for softsign."""
+
+    x = g.get_node(op.input("X")[0])
+    dtype = infer_type(x).checked_type.dtype
+    out = x / (_op.const(1.0, dtype) + _op.abs(x))
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_square(g, op, block):
     """Operator converter for square."""
 
@@ -1088,6 +1145,23 @@ def convert_squeeze(g, op, block):
     g.add_node(op.output("Out")[0], x)
 
 
+def convert_swish(g, op, block):
+    """Operator converter for swish."""
+
+    x = g.get_node(op.input("X")[0])
+    dtype = infer_type(x).checked_type.dtype
+    out = x / (_op.const(1.0, dtype) + _op.exp(_op.const(-1.0, dtype) * x))
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_transpose(g, op, block):
+    """Operator converter for transpose."""
+
+    perm = op.attr("axis")
+    out = _op.transpose(g.get_node(op.input("X")[0]), axes=perm)
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_unsqueeze(g, op, block):
     """Operator converter for unsqueeze."""
 
@@ -1172,6 +1246,7 @@ def convert_unsqueeze(g, op, block):
     "pad2d": convert_padding,
     "pad3d": convert_padding,
     "pool2d": convert_pool2d,
+    "pow": convert_pow,
     "relu": convert_unary_op,
     "relu6": convert_relu6,
     "reshape2": convert_reshape,
@@ -1188,18 +1263,24 @@ def convert_unsqueeze(g, op, block):
     "scale": convert_scale,
     "scatter": convert_scatter,
     "scatter_nd_add": convert_scatter_nd_add,
+    "selu": convert_selu,
     "shape": convert_shape,
     "sigmoid": convert_unary_op,
     "sign": convert_unary_op,
     "sin": convert_unary_op,
     "sinh": convert_unary_op,
+    "size": convert_size,
     "slice": convert_slice,
     "softmax": convert_softmax,
+    "softplus": convert_softplus,
+    "softsign": convert_softsign,
     "sqrt": convert_unary_op,
     "square": convert_square,
     "squeeze2": convert_squeeze,
+    "swish": convert_swish,
     "tan": convert_unary_op,
     "tanh": convert_unary_op,
+    "transpose2": convert_transpose,
     "unsqueeze2": convert_unsqueeze,
 }
 

From 6648df0bdd84c4a8c2cd20648776b59948bf5fd5 Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Tue, 2 Nov 2021 13:27:53 +0000
Subject: [PATCH 5/5] add unit tests

---
 python/tvm/relay/frontend/paddlepaddle.py     |   4 +-
 .../frontend/paddlepaddle/test_forward.py     | 250 +++++++++++++++++-
 2 files changed, 246 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index e11086b1f083..fa7c80c912d9 100644
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -125,7 +125,7 @@ def convert_addmm(g, op, block):
     alpha = op.attr("Alpha")
     beta = op.attr("Beta")
     dtype = block.var(op.output("Out")[0]).dtype
-    dtype = convert_dtype_value(dtype)
+    dtype = _convert_dtype_value(dtype)
 
     if not isinstance(alpha, _expr.Expr) and alpha != 1:
         alpha = _expr.const(alpha, dtype)
@@ -1201,6 +1201,7 @@ def convert_unsqueeze(g, op, block):
     "elementwise_floordiv": convert_elementwise_op,
     "elementwise_max": convert_elementwise_op,
     "elementwise_min": convert_elementwise_op,
+    "elementwise_mod": convert_elementwise_op,
     "elementwise_mul": convert_elementwise_op,
     "elementwise_pow": convert_elementwise_op,
     "elementwise_prod": convert_elementwise_op,
@@ -1242,6 +1243,7 @@ def convert_unsqueeze(g, op, block):
     "matmul": convert_matmul,
     "matmul_v2": convert_matmul,
     "mul": convert_mul,
+    "not_equal": convert_elementwise_op,
     "pad1d": convert_padding,
     "pad2d": convert_padding,
     "pad3d": convert_padding,
diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py
index 513c4b206c4d..b8d4c1150238 100644
--- a/tests/python/frontend/paddlepaddle/test_forward.py
+++ b/tests/python/frontend/paddlepaddle/test_forward.py
@@ -125,6 +125,33 @@ def add_subtract3(inputs1, inputs2):
     verify_model(add_subtract3, [input_data, input_data2])
 
 
+@tvm.testing.uses_gpu
+def test_forward_addmm():
+    class Addmm(nn.Layer):
+        def __init__(self, alpha=1.0, beta=1.0):
+            super(Addmm, self).__init__()
+            self.alpha = alpha
+            self.beta = beta
+
+        @paddle.jit.to_static
+        def forward(self, inputs, x, y):
+            return paddle.addmm(inputs, x, y, self.alpha, self.beta)
+
+    input_shapes = [[10, 10], [1, 1], [7, 1]]
+    x_shapes = [[10, 3], [5, 6], [7, 7]]
+    y_shapes = [[3, 10], [6, 2], [7, 3]]
+    input_shapes = [[10, 10]]
+    x_shapes = [[10, 3]]
+    y_shapes = [[3, 10]]
+
+    for i in range(len(input_shapes)):
+        input_data = paddle.rand(input_shapes[i], dtype="float32")
+        x_data = paddle.rand(x_shapes[i], dtype="float32")
+        y_data = paddle.rand(y_shapes[i], dtype="float32")
+        verify_model(Addmm(), input_data=[input_data, x_data, y_data])
+        verify_model(Addmm(0.5, 0.3), input_data=[input_data, x_data, y_data])
+
+
 @tvm.testing.uses_gpu
 def test_forward_arg_max_min():
     class ArgMax(nn.Layer):
@@ -279,6 +306,24 @@ def forward(self, input_data):
     verify_model(BatchNorm3D(), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_bmm():
+    class Bmm(nn.Layer):
+        def __init__(self):
+            super(Bmm, self).__init__()
+
+        @paddle.jit.to_static
+        def forward(self, x, y):
+            return paddle.bmm(x, y)
+
+    x_shapes = [[10, 3, 4], [5, 6, 2], [1, 7, 7]]
+    y_shapes = [[10, 4, 5], [5, 2, 7], [1, 7, 3]]
+    for i in range(len(x_shapes)):
+        x_data = paddle.rand(x_shapes[i], dtype="float32")
+        y_data = paddle.rand(y_shapes[i], dtype="float32")
+        verify_model(Bmm(), input_data=[x_data, y_data])
+
+
 @tvm.testing.uses_gpu
 def test_forward_cast():
     @paddle.jit.to_static
@@ -461,15 +506,25 @@ def forward(self, input1, input2):
 
     api_list = [
         "equal",
+        "floor_divide",
+        "greater_equal",
+        "greater_than",
         "less_equal",
         "less_than",
+        "maximum",
+        "minimum",
+        "pow",
     ]
     x_shapes = [[128], [8, 20], [4, 20, 3], [2, 3, 8, 8], [2, 3, 3, 9, 9]]
     y_shapes = [[1], [8, 20], [4, 1, 1], [2, 3, 8, 8], [2, 3, 3, 9, 1]]
     for x_shape, y_shape in zip(x_shapes, y_shapes):
-        x_data = paddle.randint(1, 1000, x_shape, dtype="int32")
-        y_data = paddle.randint(1, 1000, y_shape, dtype="int32")
+        x_data = paddle.randint(1, 10, x_shape, dtype="int32")
+        y_data = paddle.randint(1, 10, y_shape, dtype="int32")
         for api_name in api_list:
+            if api_name == "pow":
+                # only support float for pow
+                x_data = x_data.astype("float32")
+                y_data = y_data.astype("float32")
             verify_model(ElemwiseAPI(api_name), [x_data, y_data])
 
 
@@ -530,6 +585,100 @@ def forward(self, x, y):
     verify_model(ExpandAs(), [x_data, y_data])
 
 
+@tvm.testing.uses_gpu
+def test_forward_gather():
+    class Gather(nn.Layer):
+        def __init__(self, axis=None):
+            super(Gather, self).__init__()
+            self.axis = axis
+
+        @paddle.jit.to_static
+        def forward(self, x, index):
+            return paddle.gather(x, index, axis=self.axis)
+
+    x_shapes = [[20, 10], [10, 10, 8]]
+    index = paddle.to_tensor(np.array([1, 3, 5]).astype("int64"))
+    for x_shape in x_shapes:
+        x_data = paddle.rand(x_shape, dtype="float32")
+        verify_model(Gather(), [x_data, index])
+        verify_model(Gather(axis=0), [x_data, index])
+        verify_model(Gather(axis=1), [x_data, index])
+
+
+@tvm.testing.uses_gpu
+def test_forward_gather_nd():
+    class GatherNd(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, x, index):
+            return paddle.gather_nd(x, index)
+
+    x_shapes = [[20], [8, 8], [4, 5, 6], [3, 4, 3, 5]]
+    y_shapes = [[2, 1], [2], [1, 2, 3], [3]]
+    for x_shape, y_shape in zip(x_shapes, y_shapes):
+        x_data = paddle.rand(x_shape, dtype="float32")
+        y_data = paddle.randint(low=0, high=3, shape=y_shape, dtype="int64")
+        verify_model(GatherNd(), [x_data, y_data])
+
+
+@tvm.testing.uses_gpu
+def test_forward_group_norm():
+    class GroupNorm(nn.Layer):
+        def __init__(self, channels, groups):
+            super(GroupNorm, self).__init__()
+            self.group_norm = paddle.nn.GroupNorm(num_channels=channels, num_groups=groups)
+
+        def forward(self, inputs):
+            return self.group_norm(inputs)
+
+    input_shapes = [[1, 4, 6, 6], [2, 2, 4, 7], [2, 8, 1, 1]]
+    for input_shape in input_shapes:
+        num_channels = input_shape[1]
+        input_data = paddle.uniform(input_shape)
+        verify_model(GroupNorm(num_channels, 1), input_data)
+        verify_model(GroupNorm(num_channels, 2), input_data)
+
+
+@tvm.testing.uses_gpu
+def test_forward_scatter():
+    class Scatter(nn.Layer):
+        def __init__(self, overwrite=True):
+            super(Scatter, self).__init__()
+            self.overwrite = overwrite
+
+        @paddle.jit.to_static
+        def forward(self, x, index, updates):
+            return paddle.scatter(x, index, updates, overwrite=self.overwrite)
+
+    x_shapes = [[10], [4, 5], [6, 4, 5], [4, 5, 6, 4]]
+    index_shapes = [[10], [4], [6], [4]]
+    for x_shape, index_shape in zip(x_shapes, index_shapes):
+        x_data = paddle.rand(x_shape, dtype="float32")
+        updates = paddle.rand(x_shape, dtype="float32") + 1.0
+        index = paddle.randint(low=0, high=3, shape=index_shape)
+        verify_model(Scatter(), [x_data, index, updates])
+        verify_model(Scatter(False), [x_data, index, updates])
+
+
+def test_forward_scatter_nd():
+    @paddle.jit.to_static
+    def scatter_nd(index, updates):
+        shape = [3, 5, 9, 10]
+        return paddle.scatter_nd(index, updates, shape)
+
+    @paddle.jit.to_static
+    def scatter_nd_add(x, index, updates):
+        return paddle.scatter_nd_add(x, index, updates)
+
+    index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64)
+    index = paddle.to_tensor(index_data)
+    updates = paddle.rand(shape=[3, 9, 10], dtype="float32")
+    verify_model(scatter_nd, [index, updates])
+    x = paddle.rand(shape=[3, 5, 4, 9, 10], dtype="float32")
+    updates = paddle.rand(shape=[3, 2, 9, 10], dtype="float32")
+    index = paddle.randint(0, 3, shape=[3, 2, 3])
+    verify_model(scatter_nd_add, [x, index, updates])
+
+
 @tvm.testing.uses_gpu
 def test_forward_shape_full():
     @paddle.jit.to_static
@@ -676,6 +825,22 @@ def forward(self, x, y):
     verify_model(LogicalAPI("logical_xor"), [x_data, y_data])
 
 
+@tvm.testing.uses_gpu
+def test_forward_logical_not():
+    class LogicalNot(nn.Layer):
+        def __init__(self):
+            super(LogicalNot, self).__init__()
+
+        @paddle.jit.to_static
+        def forward(self, x):
+            return paddle.logical_not(x).astype("int32")
+
+    input_shapes = [[128], [8, 20], [4, 20, 3], [2, 3, 8, 8], [2, 3, 3, 9, 9]]
+    for input_shape in input_shapes:
+        input_data = paddle.randint(-2, 2, input_shape).astype("bool")
+        verify_model(LogicalNot(), input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_look_up():
     @paddle.jit.to_static
@@ -780,6 +945,48 @@ def forward(self, inputs):
     verify_model(Pool2D3(), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_pad1d():
+    class Pad1D(nn.Layer):
+        def __init__(self, padding=0, mode="constant", value=0.0, data_format="NCL"):
+            super(Pad1D, self).__init__()
+            self.pad1d = paddle.nn.Pad1D(padding, mode=mode, value=value, data_format=data_format)
+
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return self.pad1d(inputs)
+
+    input_shapes = [[1, 2, 5], [2, 5, 9]]
+    for input_shape in input_shapes:
+        input_data = paddle.rand(input_shape, dtype="float32")
+        verify_model(Pad1D(padding=2), input_data=input_data)
+        verify_model(Pad1D(padding=[1, 2], data_format="NLC"), input_data=input_data)
+        verify_model(Pad1D(padding=[0, 2], value=0.3), input_data=input_data)
+        verify_model(Pad1D(padding=[2, 2], mode="reflect"), input_data=input_data)
+        verify_model(Pad1D(padding=3, mode="replicate"), input_data=input_data)
+
+
+@tvm.testing.uses_gpu
+def test_forward_pad2d():
+    class Pad2D(nn.Layer):
+        def __init__(self, padding=0, mode="constant", value=0.0, data_format="NCHW"):
+            super(Pad2D, self).__init__()
+            self.pad2d = paddle.nn.Pad2D(padding, mode=mode, value=value, data_format=data_format)
+
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return self.pad2d(inputs)
+
+    input_shapes = [[1, 2, 5, 5], [2, 2, 5, 9]]
+    for input_shape in input_shapes:
+        input_data = paddle.rand(input_shape, dtype="float32")
+        verify_model(Pad2D(padding=2), input_data=input_data)
+        verify_model(Pad2D(padding=[1, 2, 0, 2], data_format="NHWC"), input_data=input_data)
+        verify_model(Pad2D(padding=[1, 2, 0, 2], value=0.3), input_data=input_data)
+        verify_model(Pad2D(padding=[1, 2, 0, 2], mode="reflect"), input_data=input_data)
+        verify_model(Pad2D(padding=3, mode="replicate"), input_data=input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_pad3d():
     class Pad3D(nn.Layer):
@@ -794,11 +1001,28 @@ def forward(self, inputs):
     input_shapes = [[1, 2, 2, 5, 5], [1, 2, 2, 5, 9]]
     for input_shape in input_shapes:
         input_data = paddle.rand(input_shape, dtype="float32")
-        verify_model(Pad3D(padding=2), input_data=input_data)
-        verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1]), input_data=input_data)
-        verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], value=0.3), input_data=input_data)
-        verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], mode="reflect"), input_data=input_data)
-        verify_model(Pad3D(padding=3, mode="replicate"), input_data=input_data)
+        verify_model(Pad3D(padding=2), input_data=input_data)
+        verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], data_format="NDHWC"), input_data=input_data)
+        verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], value=0.3), input_data=input_data)
+        verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], mode="reflect"), input_data=input_data)
+        verify_model(Pad3D(padding=3, mode="replicate"), input_data=input_data)
+
+
+@tvm.testing.uses_gpu
+def test_forward_transpose():
+    class Transpose(nn.Layer):
+        def __init__(self, perm):
+            super(Transpose, self).__init__()
+            self.perm = perm
+
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            inputs = inputs + inputs.size()
+            return paddle.transpose(inputs, perm=self.perm)
+
+    input_data = paddle.rand([1, 3, 5, 4, 3], dtype="float32")
+    verify_model(Transpose([0, 1, 2, 3, 4]), input_data=input_data)
+    verify_model(Transpose([4, 3, 2, 0, 1]), input_data=input_data)
 
 
 @tvm.testing.uses_gpu
@@ -938,17 +1162,26 @@ def forward(self, inputs):
         "erf",
         "exp",
         "floor",
+        "hardshrink",
+        "hardtanh",
         "log",
         "log2",
         "log10",
+        "reciprocal",
         "relu",
+        "relu6",
         "round",
         "rsqrt",
+        "selu",
         "sigmoid",
         "sign",
         "sin",
         "sinh",
+        "softplus",
+        "softsign",
         "sqrt",
+        "square",
+        "swish",
         "tan",
         "tanh",
     ]
@@ -956,6 +1189,9 @@ def forward(self, inputs):
     for input_shape in input_shapes:
         input_data = paddle.rand(input_shape, dtype="float32")
         for api_name in api_list:
+            if api_name in ["log", "log2", "log10", "reciprocal", "sqrt", "rsqrt"]:
+                # avoid illegal input, all elements should be positive
+                input_data = paddle.uniform(input_shape, min=0.01, max=0.99)
             verify_model(MathAPI(api_name), input_data=input_data)
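
The overwrite=False branch of convert_scatter in patch 2 composes two Relay scatters to emulate Paddle's accumulating scatter: a scatter_add over a zero tensor collects the summed updates, while a plain scatter writes zeros into the rows of x that the indices touch, so adding the two keeps untouched rows and replaces indexed rows with the sum of their updates. A minimal NumPy sketch of that decomposition (illustration only, not part of the patches; scatter_overwrite_false is a hypothetical name, and np.add.at stands in for Relay's scatter_add):

    import numpy as np

    def scatter_overwrite_false(x, index, updates):
        # Rows hit by `index` end up holding the SUM of all updates that
        # target them; rows never indexed keep their original values.
        acc = np.zeros_like(x)
        np.add.at(acc, index, updates)  # scatter_add onto a zero tensor
        base = x.copy()
        base[index] = 0                 # scatter zeros over the hit rows
        return acc + base

    x = np.arange(12, dtype="float32").reshape(4, 3)
    index = np.array([1, 1, 3])
    updates = np.ones((3, 3), dtype="float32")
    # Row 1 accumulates two updates (sum 2.0), row 3 gets one update,
    # rows 0 and 2 pass through unchanged.
    print(scatter_overwrite_false(x, index, updates))

    # The same spot-check style works for the simpler rewrites, e.g. the
    # mask-and-multiply form used by convert_hard_shrink:
    v = np.linspace(-2, 2, 9).astype("float32")
    threshold = 0.5
    mask = np.logical_or(v < -threshold, v > threshold)
    assert np.allclose(mask.astype("float32") * v,
                       np.where(np.abs(v) > threshold, v, 0.0))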